
Commit 7d355a6: pytorch-level profiler
1 parent 58dba8a
3 files changed: 124 additions & 12 deletions

examples/eval.py: 53 additions & 12 deletions
```diff
@@ -1,3 +1,4 @@
+import base64
 import dataclasses
 import multiprocessing
 import re
@@ -137,6 +138,17 @@ def _clone_data(data):
     return data


+def wrap_check_implementation(data, submission_output):
+    # Old version returned just a single string, new version
+    # returns (bool, str); this function ensures compatibility with old
+    # problem definitions.
+    result = check_implementation(data, submission_output)
+    if isinstance(result, tuple):
+        return result
+    else:
+        return not bool(result), result
+
+
 def _run_single_test(test: TestCase):
     """
     Runs a single test case. Do not call directly
```
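The shim in `wrap_check_implementation` normalizes both checker conventions to `(bool, str)`. A minimal runnable sketch of that behavior; `wrap_check` here is a hypothetical variant that takes the checker as a parameter, since the real `check_implementation` comes from the problem definition:

```python
def wrap_check(check, data, submission_output):
    # Same normalization as wrap_check_implementation in the diff above.
    result = check(data, submission_output)
    if isinstance(result, tuple):
        return result                    # new style: already (passed, message)
    return not bool(result), result      # old style: truthy string means failure

legacy = lambda d, o: "" if o == d else "mismatch"          # old: error string
modern = lambda d, o: (o == d, "ok" if o == d else "bad")   # new: (bool, str)

print(wrap_check(legacy, 1, 1))  # (True, '')
print(wrap_check(legacy, 1, 2))  # (False, 'mismatch')
print(wrap_check(modern, 1, 2))  # (False, 'bad')
```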
```diff
@@ -146,7 +158,7 @@ def _run_single_test(test: TestCase):
     torch.cuda.synchronize()
     submission_output = custom_kernel(_clone_data(data))
     torch.cuda.synchronize()
-    return check_implementation(data, submission_output)
+    return wrap_check_implementation(data, submission_output)


 def run_single_test(pool: multiprocessing.Pool, test: TestCase):
@@ -168,13 +180,15 @@ def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[T
     logger.log("test-count", len(tests))
     for idx, test in enumerate(tests):
         logger.log(f"test.{idx}.spec", test.spec)
-        error = run_single_test(pool, test)
-        if error:
+        good, message = run_single_test(pool, test)
+        if not good:
             logger.log(f"test.{idx}.status", "fail")
-            logger.log(f"test.{idx}.error", error)
+            logger.log(f"test.{idx}.error", message)
             passed = False
         else:
             logger.log(f"test.{idx}.status", "pass")
+            if message:
+                logger.log(f"test.{idx}.message", message)

     if passed:
         logger.log("check", "pass")
@@ -196,9 +210,9 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t
     check_copy = _clone_data(data)
     # first, one obligatory correctness check
     output = custom_kernel(data)
-    error = check_implementation(check_copy, output)
-    if error:
-        return error
+    good, message = wrap_check_implementation(check_copy, output)
+    if not good:
+        return message

     # now, do multiple timing runs without further correctness testing
     # there is an upper bound of 100 runs, and a lower bound of 3 runs;
@@ -220,16 +234,16 @@ def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_t
         end = time.perf_counter_ns()

         if recheck:
-            error = check_implementation(check_copy, output)
-            if error:
-                return error
+            good, message = check_implementation(check_copy, output)
+            if not good:
+                return message

         del output
         durations.append(end-start)

         if i > 1:
             stats = calculate_stats(durations)
-            if stats.err / stats.mean < 0.01 or stats.mean * stats.runs > max_time_ns:
+            if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns:
                 break

     return calculate_stats(durations)
```
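The loop's stopping rule is tightened from a 1% to a 0.1% relative standard error. A sketch of the criterion, assuming `calculate_stats` returns the sample mean and the standard error of the mean over `runs` samples (field names mirror the diff; the implementation below is illustrative, not the harness's):

```python
import math

def calc_stats(durations):
    # Illustrative stand-in for calculate_stats: mean, standard error
    # of the mean, and sample count over the recorded durations (ns).
    n = len(durations)
    mean = sum(durations) / n
    var = sum((d - mean) ** 2 for d in durations) / (n - 1)
    return mean, math.sqrt(var / n), n

durations = [1_050_000, 990_000, 1_010_000]  # hypothetical timings in ns
mean, err, runs = calc_stats(durations)
max_time_ns = 10_000_000_000  # hypothetical budget

# Stop once the mean is pinned down to ~0.1%, or the accumulated
# time (mean * runs) has exhausted the benchmarking budget.
stop = err / mean < 0.001 or mean * runs > max_time_ns
```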
```diff
@@ -282,6 +296,31 @@ def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: l
         return 112


+def run_single_profile(test: TestCase) -> str:
+    """
+    Runs a single test case. Do not call directly
+    """
+    from submission import custom_kernel
+    from torch.profiler import profile, record_function, ProfilerActivity
+    data = generate_input(**test.args)
+    torch.cuda.synchronize()
+
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        submission_output = custom_kernel(_clone_data(data))
+        torch.cuda.synchronize()
+    return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20)
+
+
+def run_profiling(logger: PopcornOutput, tests: list[TestCase]):
+    logger.log("benchmark-count", len(tests))
+    for idx, test in enumerate(tests):
+        logger.log(f"benchmark.{idx}.spec", test.spec)
+        report = run_single_profile(test)
+        logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"))
+    logger.log("check", "pass")
+    return 0
+
+
 def main():
     fd = os.getenv("POPCORN_FD")
     if not fd:
```
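`run_single_profile` wraps a single kernel invocation in `torch.profiler` and returns the aggregated op table. The same pattern standalone, with `toy_kernel` as a stand-in for the submission's `custom_kernel`:

```python
import torch
from torch.profiler import profile, ProfilerActivity

def toy_kernel(x):
    # Stand-in workload; the harness profiles the submission's custom_kernel.
    return torch.softmax(x @ x.T, dim=-1)

x = torch.randn(1024, 1024, device="cuda")
torch.cuda.synchronize()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    toy_kernel(x)
    torch.cuda.synchronize()  # make sure all kernels land inside the profile

# One aggregated row per op, sorted by time spent in the op's own CUDA
# kernels (excluding children), limited to the 20 busiest entries.
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))
```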
```diff
@@ -324,8 +363,10 @@ def main():
             break

         logger.log("check", "pass" if passed else "fail")
+    elif mode == "profile":
+        run_profiling(logger, tests)
     else:
-        # TODO: Implement script and profile mode
+        # TODO: Implement script mode
        return 2


```
src/discord-cluster-manager/cogs/leaderboard_cog.py: 19 additions & 0 deletions
```diff
@@ -279,6 +279,25 @@ async def submit_bench(
             interaction, leaderboard_name, script, mode=SubmissionMode.BENCHMARK, gpu=gpu
         )

+    @app_commands.command(name="profile", description="Start a profiling run")
+    @app_commands.describe(
+        leaderboard_name="Name of the competition / kernel to optimize",
+        script="The Python / CUDA script file to run",
+        gpu="Select GPU. Leave empty for interactive or automatic selection.",
+    )
+    @app_commands.autocomplete(leaderboard_name=leaderboard_name_autocomplete)
+    @with_error_handling
+    async def submit_profile(
+        self,
+        interaction: discord.Interaction,
+        script: discord.Attachment,
+        leaderboard_name: Optional[str],
+        gpu: Optional[str],
+    ):
+        return await self.submit(
+            interaction, leaderboard_name, script, mode=SubmissionMode.PROFILE, gpu=gpu
+        )
+
     @app_commands.command(
         name="ranked", description="Start a ranked run for an official leaderboard submission"
     )
```

src/discord-cluster-manager/report.py: 52 additions & 0 deletions
```diff
@@ -1,4 +1,6 @@
+import base64
 import dataclasses
+import textwrap
 from typing import List

 import consts
@@ -195,6 +197,17 @@ def make_short_report(runs: dict[str, EvalResult], full=True) -> list[str]: # n
     elif full:
         result.append("❌ Benchmarks missing")

+    if "profile" in runs:
+        bench_run = runs["profile"].run
+        if not bench_run.success:
+            result.append("❌ Running profile failed" + _short_fail_reason(bench_run))
+            return result
+        elif not bench_run.passed:
+            result.append("❌ Profiling failed")
+            return result
+        else:
+            result.append("✅ Profiling successful")
+
     if "leaderboard" in runs:
         lb_run = runs["leaderboard"].run
         if not lb_run.success:
```
```diff
@@ -263,6 +276,29 @@ def log_one(base_name):
     return "❗ Could not find any benchmarks"


+def make_profile_log(run: RunResult) -> str:
+    num_bench = int(run.result.get("benchmark-count", 0))
+
+    def log_one(base_name):
+        spec = run.result.get(f"{base_name}.spec")
+
+        report: str = run.result.get(f"{base_name}.report")
+        report = base64.b64decode(report.encode("utf-8"), b"+*").decode("utf-8")
+        report = textwrap.indent(report, " ")
+        bench_log.append(f"{spec}\n")
+        bench_log.append(report)
+
+    bench_log = []
+    for i in range(num_bench):
+        log_one(f"benchmark.{i}")
+        bench_log.append("")
+
+    if len(bench_log) > 0:
+        return "\n".join(bench_log)
+    else:
+        return "❗ Could not find any profiling data"
+
+
 def generate_system_info(system: SystemInfo):
     return f"""
 Running on:
```
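`make_profile_log` reverses the encoding done in `eval.py`: the profiler table travels through the logger as base64 with `altchars=b"+*"`, which keeps `+` and substitutes `*` for `/` in the encoded output, presumably so the payload avoids characters that would clash with the log transport. A round trip under that reading:

```python
import base64

report = "aten::mm  45.3%  1.204ms"  # hypothetical profiler table fragment
encoded = base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")
# Decoding must pass the same altchars, exactly as make_profile_log does.
decoded = base64.b64decode(encoded.encode("utf-8"), b"+*").decode("utf-8")
assert decoded == report
```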
```diff
@@ -314,6 +350,22 @@ def generate_report(result: FullResult) -> RunResultReport: # noqa: C901
             make_benchmark_log(bench_run),
         )

+    if "profile" in runs:
+        prof_run = runs["profile"]
+        if prof_run.compilation is not None and not prof_run.compilation.success:
+            _generate_compile_report(report, prof_run.compilation)
+            return report
+
+        prof_run = prof_run.run
+        if not prof_run.success:
+            _generate_crash_report(report, prof_run)
+            return report
+
+        report.add_log(
+            "Profiling",
+            make_profile_log(prof_run),
+        )
+
     if "leaderboard" in runs:
         bench_run = runs["leaderboard"]
         if bench_run.compilation is not None and not bench_run.compilation.success:
```
