|
2 | 2 | Handles the loading of the bench_runner.toml configuration file. |
3 | 3 | """ |
4 | 4 |
|
| 5 | +import dataclasses |
5 | 6 | import functools |
6 | 7 | from pathlib import Path |
7 | 8 | import tomllib |
8 | | -from typing import Any |
9 | 9 |
|
10 | 10 |
|
11 | | -from . import runners |
| 11 | +from . import flags as mflags |
| 12 | +from . import plot as mplot |
| 13 | +from . import runners as mrunners |
12 | 14 | from .util import PathLike |
13 | 15 |
|
14 | 16 |
|
@dataclasses.dataclass
class Bases:
    """The reference builds that every benchmark run is compared against."""

    # The base versions to compare every benchmark run to.
    # Should be a fully-specified version, e.g. "3.13.0".
    versions: list[str]
    # List of configuration flags that are compared against the default build of
    # its commit merge base.
    compare_to_default: list[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        if not self.versions:
            raise RuntimeError(
                "No `bases.versions` are defined in `bench_runner.toml`. "
            )
        # Keep the normalized flags (the call's result was previously
        # discarded) so downstream comparisons see canonical spellings,
        # matching how Weekly.__post_init__ handles its flags.
        self.compare_to_default = mflags.normalize_flags(self.compare_to_default)
| 32 | + |
| 33 | + |
@dataclasses.dataclass
class Notify:
    """Settings for sending notifications about benchmark runs."""

    # Number of the GitHub issue used to send notification emails.
    notification_issue: int = 0
| 38 | + |
| 39 | + |
@dataclasses.dataclass
class PublishMirror:
    """Settings for publishing results to the mirror."""

    # When True, publishing to the mirror is skipped.
    skip: bool = False
| 44 | + |
| 45 | + |
@dataclasses.dataclass
class Benchmarks:
    """Settings controlling which benchmarks appear in output."""

    # Names of benchmarks that are left out of the plots.
    excluded_benchmarks: list[str] = dataclasses.field(default_factory=list)
| 50 | + |
| 51 | + |
@dataclasses.dataclass
class Weekly:
    """One entry in the weekly benchmarking schedule."""

    # Configuration flags to build/benchmark with.
    flags: list[str] = dataclasses.field(default_factory=list)
    # Nicknames of the runners this entry executes on.
    runners: list[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        # Canonicalize flag spellings up front.
        self.flags = mflags.normalize_flags(self.flags)
| 59 | + |
| 60 | + |
@dataclasses.dataclass
class Config:
    """Typed view of the parsed `bench_runner.toml` file.

    ``__post_init__`` converts the raw TOML tables (plain dicts) into their
    typed dataclass equivalents. Values that are already typed are passed
    through unchanged, so a ``Config`` can also be built from converted
    values (the original only guarded some of the fields this way).
    """

    # The base versions every benchmark run is compared against.
    bases: Bases
    # Mapping of runner nickname -> runner definition.
    runners: dict[str, mrunners.Runner]
    publish_mirror: PublishMirror = dataclasses.field(default_factory=PublishMirror)
    benchmarks: Benchmarks = dataclasses.field(default_factory=Benchmarks)
    notify: Notify = dataclasses.field(default_factory=Notify)
    longitudinal_plot: mplot.LongitudinalPlotConfig | None = None
    flag_effect_plot: mplot.FlagEffectPlotConfig | None = None
    benchmark_longitudinal_plot: mplot.BenchmarkLongitudinalPlotConfig | None = None
    weekly: dict[str, Weekly] = dataclasses.field(default_factory=dict)

    @staticmethod
    def _coerce(cls, value):
        # Convert a raw TOML table into `cls`; leave already-typed values alone.
        return cls(**value) if isinstance(value, dict) else value

    @staticmethod
    def _coerce_optional(cls, value):
        # As `_coerce`, but a missing or empty table means "not configured"
        # (preserves the original truthiness check on the plot configs).
        if not value:
            return None
        return cls(**value) if isinstance(value, dict) else value

    def __post_init__(self):
        self.bases = self._coerce(Bases, self.bases)  # pyright: ignore[reportCallIssue]
        if len(self.runners) == 0:
            raise RuntimeError(
                "No runners are defined in `bench_runner.toml`. "
                "Please set up some runners first."
            )
        self.runners = {
            name: (
                runner
                if isinstance(runner, mrunners.Runner)
                else mrunners.Runner(
                    nickname=name, **runner  # pyright: ignore[reportCallIssue]
                )
            )
            for name, runner in self.runners.items()
        }
        self.publish_mirror = self._coerce(PublishMirror, self.publish_mirror)
        self.benchmarks = self._coerce(Benchmarks, self.benchmarks)
        self.notify = self._coerce(Notify, self.notify)
        self.longitudinal_plot = self._coerce_optional(
            mplot.LongitudinalPlotConfig, self.longitudinal_plot
        )
        self.flag_effect_plot = self._coerce_optional(
            mplot.FlagEffectPlotConfig, self.flag_effect_plot
        )
        self.benchmark_longitudinal_plot = self._coerce_optional(
            mplot.BenchmarkLongitudinalPlotConfig, self.benchmark_longitudinal_plot
        )
        self.weekly = {
            name: self._coerce(Weekly, weekly)
            for name, weekly in self.weekly.items()
        }
        if len(self.weekly) == 0:
            # No explicit schedule: default to one entry covering every runner.
            self.weekly = {"default": Weekly(runners=list(self.runners.keys()))}
| 119 | + |
| 120 | + |
@functools.cache
def get_config(filepath: PathLike | None = None) -> Config:
    """Load `bench_runner.toml` and return it as a `Config`.

    When *filepath* is None, `bench_runner.toml` in the current working
    directory is used. Results are cached per path via `functools.cache`.
    """
    path = Path("bench_runner.toml") if filepath is None else Path(filepath)
    # tomllib requires a binary-mode file object.
    with path.open("rb") as fd:
        raw = tomllib.load(fd)
    return Config(**raw)
0 commit comments