1+ import base64
12import dataclasses
23import multiprocessing
34import re
@@ -137,6 +138,17 @@ def _clone_data(data):
137138 return data
138139
139140
141+ def wrap_check_implementation (data , submission_output ):
142+ # Old version returned just a single string, new version
143+ # returns (bool, str); this function ensures compatibility with old
144+ # problem definitions.
145+ result = check_implementation (data , submission_output )
146+ if isinstance (result , tuple ):
147+ return result
148+ else :
149+ return not bool (result ), result
150+
151+
140152def _run_single_test (test : TestCase ):
141153 """
142154 Runs a single test case. Do not call directly
@@ -146,7 +158,7 @@ def _run_single_test(test: TestCase):
146158 torch .cuda .synchronize ()
147159 submission_output = custom_kernel (_clone_data (data ))
148160 torch .cuda .synchronize ()
149- return check_implementation (data , submission_output )
161+ return wrap_check_implementation (data , submission_output )
150162
151163
152164def run_single_test (pool : multiprocessing .Pool , test : TestCase ):
@@ -168,13 +180,15 @@ def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[T
168180 logger .log ("test-count" , len (tests ))
169181 for idx , test in enumerate (tests ):
170182 logger .log (f"test.{ idx } .spec" , test .spec )
171- error = run_single_test (pool , test )
172- if error :
183+ good , message = run_single_test (pool , test )
184+ if not good :
173185 logger .log (f"test.{ idx } .status" , "fail" )
174- logger .log (f"test.{ idx } .error" , error )
186+ logger .log (f"test.{ idx } .error" , message )
175187 passed = False
176188 else :
177189 logger .log (f"test.{ idx } .status" , "pass" )
190+ if message :
191+ logger .log (f"test.{ idx } .message" , message )
178192
179193 if passed :
180194 logger .log ("check" , "pass" )
@@ -196,9 +210,9 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t
196210 check_copy = _clone_data (data )
197211 # first, one obligatory correctness check
198212 output = custom_kernel (data )
199- error = check_implementation (check_copy , output )
200- if error :
201- return error
213+ good , message = wrap_check_implementation (check_copy , output )
214+ if not good :
215+ return message
202216
203217 # now, do multiple timing runs without further correctness testing
204218 # there is an upper bound of 100 runs, and a lower bound of 3 runs;
@@ -220,16 +234,16 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t
220234 end = time .perf_counter_ns ()
221235
222236 if recheck :
223- error = check_implementation (check_copy , output )
224- if error :
225- return error
237+ good , message = check_implementation (check_copy , output )
238+ if not good :
239+ return message
226240
227241 del output
228242 durations .append (end - start )
229243
230244 if i > 1 :
231245 stats = calculate_stats (durations )
232- if stats .err / stats .mean < 0.01 or stats .mean * stats .runs > max_time_ns :
246+ if stats .err / stats .mean < 0.001 or stats .mean * stats .runs > max_time_ns :
233247 break
234248
235249 return calculate_stats (durations )
@@ -282,6 +296,31 @@ def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: l
282296 return 112
283297
284298
299+ def run_single_profile (test : TestCase ) -> str :
300+ """
301+ Runs a single test case. Do not call directly
302+ """
303+ from submission import custom_kernel
304+ from torch .profiler import profile , record_function , ProfilerActivity
305+ data = generate_input (** test .args )
306+ torch .cuda .synchronize ()
307+
308+ with profile (activities = [ProfilerActivity .CPU , ProfilerActivity .CUDA ]) as prof :
309+ submission_output = custom_kernel (_clone_data (data ))
310+ torch .cuda .synchronize ()
311+ return prof .key_averages ().table (sort_by = "self_cuda_time_total" , row_limit = 20 )
312+
313+
314+ def run_profiling (logger : PopcornOutput , tests : list [TestCase ]):
315+ logger .log ("benchmark-count" , len (tests ))
316+ for idx , test in enumerate (tests ):
317+ logger .log (f"benchmark.{ idx } .spec" , test .spec )
318+ report = run_single_profile (test )
319+ logger .log (f"benchmark.{ idx } .report" , base64 .b64encode (report .encode ("utf-8" ), b"+*" ).decode ("utf-8" ))
320+ logger .log ("check" , "pass" )
321+ return 0
322+
323+
285324def main ():
286325 fd = os .getenv ("POPCORN_FD" )
287326 if not fd :
@@ -324,8 +363,10 @@ def main():
324363 break
325364
326365 logger .log ("check" , "pass" if passed else "fail" )
366+ elif mode == "profile" :
367+ run_profiling (logger , tests )
327368 else :
328- # TODO: Implement script and profile mode
369+ # TODO: Implement script mode
329370 return 2
330371
331372
0 commit comments