Skip to content

Commit 63fb91c

Browse files
committed
move build_task_config into task.py
1 parent df65573 commit 63fb91c

3 files changed

Lines changed: 73 additions & 78 deletions

File tree

src/discord-cluster-manager/cogs/submit_cog.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818
make_short_report,
1919
)
2020
from run_eval import FullResult
21-
from task import LeaderboardTask
21+
from task import LeaderboardTask, build_task_config
2222
from utils import (
2323
KernelBotError,
24-
build_task_config,
2524
send_discord_message,
2625
setup_logging,
2726
with_error_handling,

src/discord-cluster-manager/task.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pathlib import Path
66
from typing import Dict, Optional, Union
77

8-
from consts import Language, RankCriterion
8+
from consts import Language, RankCriterion, SubmissionMode
99
from utils import KernelBotError
1010

1111

@@ -134,3 +134,73 @@ def make_task(yaml_file: str | Path) -> LeaderboardTask:
134134

135135
if __name__ == "__main__":
136136
print(json.dumps(make_task("task.yml").to_dict(), indent=4))
137+
138+
139+
def build_task_config(
140+
task: "LeaderboardTask" = None,
141+
submission_content: str = None,
142+
arch: str = None,
143+
mode: SubmissionMode = None,
144+
) -> dict:
145+
if task is None:
146+
assert mode == SubmissionMode.SCRIPT
147+
# TODO detect language
148+
lang = "py"
149+
150+
config = {
151+
"lang": lang,
152+
"arch": arch,
153+
}
154+
155+
eval_name = {"py": "eval.py", "cu": "eval.cu"}[lang]
156+
157+
if lang == "py":
158+
config["main"] = "eval.py"
159+
160+
return {
161+
**config,
162+
"sources": {
163+
eval_name: submission_content,
164+
},
165+
}
166+
else:
167+
all_files = {}
168+
for n, c in task.files.items():
169+
if c == "@SUBMISSION@":
170+
all_files[n] = submission_content
171+
else:
172+
all_files[n] = c
173+
174+
common = {
175+
"lang": task.lang.value,
176+
"arch": arch,
177+
"benchmarks": task.benchmarks,
178+
"tests": task.tests,
179+
"mode": mode.value,
180+
"test_timeout": task.test_timeout,
181+
"benchmark_timeout": task.benchmark_timeout,
182+
"ranked_timeout": task.ranked_timeout,
183+
"ranking_by": task.ranking_by.value,
184+
"seed": task.seed,
185+
}
186+
187+
if task.lang == Language.Python:
188+
return {
189+
"main": task.config.main,
190+
"sources": all_files,
191+
**common,
192+
}
193+
else:
194+
sources = {}
195+
headers = {}
196+
for f in all_files:
197+
if f in task.config.sources:
198+
sources[f] = all_files[f]
199+
else:
200+
headers[f] = all_files[f]
201+
202+
return {
203+
"sources": sources,
204+
"headers": headers,
205+
"include_dirs": task.config.include_dirs,
206+
}

src/discord-cluster-manager/utils.py

Lines changed: 1 addition & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import functools
22
import logging
33
import subprocess
4-
from typing import TYPE_CHECKING, Any, Optional
4+
from typing import Any, Optional
55

66
import discord
7-
from consts import Language, SubmissionMode
8-
9-
if TYPE_CHECKING:
10-
from task import LeaderboardTask
117

128

139
def setup_logging(name: Optional[str] = None):
@@ -178,76 +174,6 @@ def invalidate(self):
178174
self._q.clear()
179175

180176

181-
def build_task_config(
182-
task: "LeaderboardTask" = None,
183-
submission_content: str = None,
184-
arch: str = None,
185-
mode: SubmissionMode = None,
186-
) -> dict:
187-
if task is None:
188-
assert mode == SubmissionMode.SCRIPT
189-
# TODO detect language
190-
lang = "py"
191-
192-
config = {
193-
"lang": lang,
194-
"arch": arch,
195-
}
196-
197-
eval_name = {"py": "eval.py", "cu": "eval.cu"}[lang]
198-
199-
if lang == "py":
200-
config["main"] = "eval.py"
201-
202-
return {
203-
**config,
204-
"sources": {
205-
eval_name: submission_content,
206-
},
207-
}
208-
else:
209-
all_files = {}
210-
for n, c in task.files.items():
211-
if c == "@SUBMISSION@":
212-
all_files[n] = submission_content
213-
else:
214-
all_files[n] = c
215-
216-
common = {
217-
"lang": task.lang.value,
218-
"arch": arch,
219-
"benchmarks": task.benchmarks,
220-
"tests": task.tests,
221-
"mode": mode.value,
222-
"test_timeout": task.test_timeout,
223-
"benchmark_timeout": task.benchmark_timeout,
224-
"ranked_timeout": task.ranked_timeout,
225-
"ranking_by": task.ranking_by.value,
226-
"seed": task.seed,
227-
}
228-
229-
if task.lang == Language.Python:
230-
return {
231-
"main": task.config.main,
232-
"sources": all_files,
233-
**common,
234-
}
235-
else:
236-
sources = {}
237-
headers = {}
238-
for f in all_files:
239-
if f in task.config.sources:
240-
sources[f] = all_files[f]
241-
else:
242-
headers[f] = all_files[f]
243-
244-
return {
245-
"sources": sources,
246-
"headers": headers,
247-
"include_dirs": task.config.include_dirs,
248-
}
249-
250-
251177
def format_time(value: float | str, err: Optional[float | str] = None, scale=None): # noqa: C901
252178
if value is None:
253179
logging.warning("Expected a number, got None", stack_info=True)

0 commit comments

Comments
 (0)