Skip to content

Commit 1694935

Browse files
committed
Address comments from PR
1 parent 3d867c6 commit 1694935

5 files changed

Lines changed: 16 additions & 16 deletions

File tree

bench_runner/config.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@ class PublishMirror:
4646
@dataclasses.dataclass
4747
class Benchmarks:
4848
# Benchmarks to exclude from plots.
49-
excluded_benchmarks: list[str] = dataclasses.field(default_factory=list)
49+
excluded_benchmarks: set[str] = dataclasses.field(default_factory=set)
50+
51+
def __post_init__(self):
52+
self.excluded_benchmarks = set(self.excluded_benchmarks)
5053

5154

5255
@dataclasses.dataclass

bench_runner/hpt.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from numpy.typing import NDArray
3636

3737

38-
from . import util
38+
from . import config
3939
from .util import PathLike
4040

4141
ACC_MAXSU = 2
@@ -68,8 +68,9 @@ def load_data(data: Mapping[str, Any]) -> dict[str, NDArray[np.float64]]:
6868
def create_matrices(
6969
a: Mapping[str, NDArray[np.float64]], b: Mapping[str, NDArray[np.float64]]
7070
) -> tuple[dict[str, NDArray[np.float64]], dict[str, NDArray[np.float64]]]:
71+
cfg = config.get_config()
7172
benchmarks = sorted(list(set(a.keys()) & set(b.keys())))
72-
excluded = util.get_excluded_benchmarks()
73+
excluded = cfg.benchmarks.excluded_benchmarks
7374
benchmarks = [bm for bm in benchmarks if bm not in excluded]
7475
return {bm: a[bm] for bm in benchmarks}, {bm: b[bm] for bm in benchmarks}
7576

bench_runner/result.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,8 @@ def calculate_diffs(ref_values, head_values) -> tuple[np.ndarray | None, float]:
205205
values.sort()
206206
return values, float(values.mean())
207207

208-
excluded = util.get_excluded_benchmarks()
208+
cfg = config.get_config()
209+
excluded = cfg.benchmarks.excluded_benchmarks
209210
combined_data = []
210211
for name, ref in ref_data.items():
211212
if len(ref) != 0 and name in head_data and name not in excluded:
@@ -689,8 +690,9 @@ def parsed_version(self):
689690
return pkg_version.parse(self.version.replace("+", "0"))
690691

691692
def get_timing_data(self) -> dict[str, np.ndarray]:
693+
cfg = config.get_config()
692694
data = {}
693-
excluded = util.get_excluded_benchmarks()
695+
excluded = cfg.benchmarks.excluded_benchmarks
694696

695697
for benchmark in self.contents["benchmarks"]:
696698
name = benchmark.get("metadata", self.contents["metadata"])["name"]
@@ -703,8 +705,9 @@ def get_timing_data(self) -> dict[str, np.ndarray]:
703705
return data
704706

705707
def get_memory_data(self) -> dict[str, np.ndarray]:
708+
cfg = config.get_config()
706709
data = {}
707-
excluded = util.get_excluded_benchmarks()
710+
excluded = cfg.benchmarks.excluded_benchmarks
708711

709712
# On MacOS, there was a bug in pyperf where the `mem_max_rss` value was
710713
# erroneously multiplied by 1024. (BSD defines maxrss in bytes, Linux

bench_runner/scripts/run_benchmarks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020

2121
from bench_runner import benchmark_definitions
22+
from bench_runner import config
2223
from bench_runner import flags
2324
from bench_runner import git
2425
from bench_runner.result import Result
@@ -366,9 +367,10 @@ def run_summarize_stats(
366367

367368

368369
def select_benchmarks(benchmarks: str):
370+
cfg = config.get_config()
369371
if benchmarks == "all":
370372
return ",".join(
371-
["all", *[f"-{x}" for x in util.get_excluded_benchmarks() if x]]
373+
["all", *[f"-{x}" for x in cfg.benchmarks.excluded_benchmarks if x]]
372374
)
373375
elif benchmarks == "all_and_excluded":
374376
return "all"

bench_runner/util.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,6 @@ def apply_suffix(path: PathLike, suffix: str) -> Path:
2727
return path_.parent / (path_.stem + suffix)
2828

2929

30-
@functools.cache
31-
def get_excluded_benchmarks() -> set[str]:
32-
from . import config
33-
34-
conf = config.get_config()
35-
excluded_benchmarks = conf.benchmarks.excluded_benchmarks
36-
return set(excluded_benchmarks)
37-
38-
3930
def has_any_element(iterable):
4031
"""
4132
Checks if an iterable (like a generator) has at least one element

0 commit comments

Comments
 (0)