Skip to content

Commit d71e9d8

Browse files
committed
REFACTOR: Load config with dataclasses, earlier validation
1 parent dfaf4ea commit d71e9d8

13 files changed

Lines changed: 332 additions & 227 deletions

File tree

bench_runner/bases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
def get_bases() -> list[str]:
    """Return the base CPython versions every benchmark run is compared to."""
    cfg = config.get_config()
    return cfg.bases.versions
1616

1717

1818
@functools.cache

bench_runner/config.py

Lines changed: 111 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,130 @@
22
Handles the loading of the bench_runner.toml configuration file.
33
"""
44

5+
import dataclasses
56
import functools
67
from pathlib import Path
78
import tomllib
8-
from typing import Any
99

1010

11-
from . import runners
11+
from . import flags as mflags
12+
from . import plot as mplot
13+
from . import runners as mrunners
1214
from .util import PathLike
1315

1416

17+
@dataclasses.dataclass
class Bases:
    """The `[bases]` section of bench_runner.toml."""

    # The base versions to compare every benchmark run to.
    # Should be a fully-specified version, e.g. "3.13.0".
    versions: list[str]
    # List of configuration flags that are compared against the default build of
    # its commit merge base.
    compare_to_default: list[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        # Fail early if no base versions are configured at all.
        if len(self.versions) == 0:
            raise RuntimeError(
                "No `bases.versions` are defined in `bench_runner.toml`. "
            )
        # Keep the normalized (sorted, de-duplicated, validated) flag list,
        # matching Weekly.__post_init__.  The original discarded the return
        # value, so only the validation side effect took place.
        self.compare_to_default = mflags.normalize_flags(self.compare_to_default)
32+
33+
34+
@dataclasses.dataclass
class Notify:
    """The `[notify]` section of bench_runner.toml."""

    # The GitHub issue number used to send notification emails.
    # 0 (the default) disables notifications (see gh.send_notification).
    notification_issue: int = 0
38+
39+
40+
@dataclasses.dataclass
class PublishMirror:
    """The `[publish_mirror]` section of bench_runner.toml."""

    # Whether to skip publishing results to the mirror.
    skip: bool = False
44+
45+
46+
@dataclasses.dataclass
class Benchmarks:
    """The `[benchmarks]` section of bench_runner.toml."""

    # Benchmarks to exclude from plots.
    excluded_benchmarks: list[str] = dataclasses.field(default_factory=list)
50+
51+
52+
@dataclasses.dataclass
class Weekly:
    """One named entry under the `[weekly]` table of bench_runner.toml."""

    # Configuration flags for this weekly run; normalized in __post_init__.
    flags: list[str] = dataclasses.field(default_factory=list)
    # Names (nicknames) of the runners this weekly run executes on.
    runners: list[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        # Sort, de-duplicate, and validate the flags (raises on unknown flags).
        self.flags = mflags.normalize_flags(self.flags)
59+
60+
61+
@dataclasses.dataclass
class Config:
    """Typed, validated representation of bench_runner.toml.

    Sections arrive from tomllib as plain dicts; __post_init__ converts
    each one into its dataclass and performs early validation.  Every
    conversion is guarded with isinstance so a Config may also be built
    from already-converted values (the original only guarded some
    sections and raised TypeError for the rest).
    """

    # Required sections.
    bases: Bases
    runners: dict[str, mrunners.Runner]
    # Optional sections, defaulting to empty/disabled.
    publish_mirror: PublishMirror = dataclasses.field(default_factory=PublishMirror)
    benchmarks: Benchmarks = dataclasses.field(default_factory=Benchmarks)
    notify: Notify = dataclasses.field(default_factory=Notify)
    longitudinal_plot: mplot.LongitudinalPlotConfig | None = None
    flag_effect_plot: mplot.FlagEffectPlotConfig | None = None
    benchmark_longitudinal_plot: mplot.BenchmarkLongitudinalPlotConfig | None = None
    weekly: dict[str, Weekly] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if isinstance(self.bases, dict):
            self.bases = Bases(**self.bases)  # pyright: ignore[reportCallIssue]
        # Fail early when no runners are configured at all.
        if len(self.runners) == 0:
            raise RuntimeError(
                "No runners are defined in `bench_runner.toml`. "
                "Please set up some runners first."
            )
        # The table key doubles as the runner's nickname.
        self.runners = {
            name: (
                runner
                if isinstance(runner, mrunners.Runner)
                else mrunners.Runner(
                    nickname=name, **runner  # pyright: ignore[reportCallIssue]
                )
            )
            for name, runner in self.runners.items()
        }
        if isinstance(self.publish_mirror, dict):
            self.publish_mirror = PublishMirror(**self.publish_mirror)
        if isinstance(self.benchmarks, dict):
            self.benchmarks = Benchmarks(**self.benchmarks)
        if isinstance(self.notify, dict):
            self.notify = Notify(**self.notify)
        # For the plot sections an empty (falsy) table means "not configured",
        # preserving the original truthiness behavior.
        if isinstance(self.longitudinal_plot, dict):
            self.longitudinal_plot = (
                mplot.LongitudinalPlotConfig(
                    **self.longitudinal_plot  # pyright: ignore[reportCallIssue]
                )
                if self.longitudinal_plot
                else None
            )
        if isinstance(self.flag_effect_plot, dict):
            self.flag_effect_plot = (
                mplot.FlagEffectPlotConfig(
                    **self.flag_effect_plot  # pyright: ignore[reportCallIssue]
                )
                if self.flag_effect_plot
                else None
            )
        if isinstance(self.benchmark_longitudinal_plot, dict):
            self.benchmark_longitudinal_plot = (
                mplot.BenchmarkLongitudinalPlotConfig(
                    **self.benchmark_longitudinal_plot  # pyright: ignore[reportCallIssue]
                )
                if self.benchmark_longitudinal_plot
                else None
            )
        self.weekly = {
            name: (
                weekly
                if isinstance(weekly, Weekly)
                else Weekly(**weekly)  # pyright: ignore[reportCallIssue]
            )
            for name, weekly in self.weekly.items()
        }
        # Default: a single weekly run covering every configured runner.
        if len(self.weekly) == 0:
            self.weekly = {"default": Weekly(runners=list(self.runners.keys()))}
119+
120+
15121
@functools.cache
def get_config(filepath: PathLike | None = None) -> Config:
    """Load, validate, and cache the bench_runner.toml configuration.

    When *filepath* is None, "bench_runner.toml" in the current working
    directory is used.  Results are cached per filepath argument.
    """
    path = Path("bench_runner.toml") if filepath is None else Path(filepath)

    with path.open("rb") as fd:
        raw = tomllib.load(fd)

    # Config.__post_init__ converts sections to dataclasses and validates.
    return Config(**raw)

bench_runner/flags.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,11 @@ def flags_to_human(flags: list[str]) -> Iterable[str]:
5656
if flag_descr.name == flag:
5757
yield flag_descr.short_name
5858
break
59+
60+
61+
def normalize_flags(flags: list[str]) -> list[str]:
    """Return *flags* de-duplicated and sorted.

    Raises ValueError if any flag is not a known value of FLAG_MAPPING.
    """
    known = set(FLAG_MAPPING.values())
    normalized = sorted(set(flags))
    for entry in normalized:
        if entry not in known:
            raise ValueError(f"Invalid flag {entry}")
    return normalized

bench_runner/gh.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
from . import config
1313
from . import flags as mflags
14-
from . import runners
1514

1615

1716
def get_machines():
    """Names of all available runners, plus the "all" pseudo-machine."""
    available = [r.name for r in config.get_config().runners.values() if r.available]
    return available + ["all"]
1919

2020

2121
def _get_flags(d: Mapping[str, Any]) -> list[str]:
@@ -71,8 +71,8 @@ def benchmark(
7171

7272

7373
def send_notification(body):
74-
conf = config.get_bench_runner_config()
75-
notification_issue = conf.get("notify", {}).get("notification_issue", 0)
74+
cfg = config.get_config()
75+
notification_issue = cfg.notify.notification_issue
7676

7777
if notification_issue == 0:
7878
print("Not sending Github notification.")

0 commit comments

Comments
 (0)