Skip to content

Commit 94b15f1

Browse files
committed
refactor: move ruleset into linter definition
1 parent dafb0c7 commit 94b15f1

4 files changed

Lines changed: 56 additions & 57 deletions

File tree

sqlmesh/core/linter/definition.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from __future__ import annotations
2-
32
import typing as t
4-
53
from sqlmesh.core.config.linter import LinterConfig
6-
74
from sqlmesh.core.model import Model
8-
95
from sqlmesh.utils.errors import raise_config_error
106
from sqlmesh.core.console import LinterConsole, get_console
11-
from sqlmesh.core.linter.rule import RuleSet
7+
import operator as op
8+
from collections.abc import Iterator, Iterable, Set, Mapping, Callable
9+
from functools import reduce
10+
from sqlmesh.core.model import Model
11+
from sqlmesh.core.linter.rule import Rule, RuleViolation
1212

1313

1414
def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
@@ -70,3 +70,50 @@ def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bo
7070
return True
7171

7272
return False
73+
74+
75+
class RuleSet(Mapping[str, type[Rule]]):
76+
def __init__(self, rules: Iterable[type[Rule]] = ()) -> None:
77+
self._underlying = {rule.name: rule for rule in rules}
78+
79+
def check_model(self, model: Model) -> t.List[RuleViolation]:
80+
violations = []
81+
82+
for rule in self._underlying.values():
83+
violation = rule().check_model(model)
84+
85+
if violation:
86+
violations.append(violation)
87+
88+
return violations
89+
90+
def __iter__(self) -> Iterator[str]:
91+
return iter(self._underlying)
92+
93+
def __len__(self) -> int:
94+
return len(self._underlying)
95+
96+
def __getitem__(self, rule: str | type[Rule]) -> type[Rule]:
97+
key = rule if isinstance(rule, str) else rule.name
98+
return self._underlying[key]
99+
100+
def __op(
101+
self,
102+
op: Callable[[Set[type[Rule]], Set[type[Rule]]], Set[type[Rule]]],
103+
other: RuleSet,
104+
/,
105+
) -> RuleSet:
106+
rules = set()
107+
for rule in op(set(self.values()), set(other.values())):
108+
rules.add(other[rule] if rule in other else self[rule])
109+
110+
return RuleSet(rules)
111+
112+
def union(self, *others: RuleSet) -> RuleSet:
113+
return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others))
114+
115+
def intersection(self, *others: RuleSet) -> RuleSet:
116+
return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others))
117+
118+
def difference(self, *others: RuleSet) -> RuleSet:
119+
return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others))

sqlmesh/core/linter/rule.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22

33
import abc
44

5-
import operator as op
6-
from collections.abc import Iterator, Iterable, Set, Mapping, Callable
7-
from functools import reduce
85

96
from sqlmesh.core.model import Model
107

@@ -48,50 +45,3 @@ def __init__(self, rule: Rule, violation_msg: str) -> None:
4845

4946
def __repr__(self) -> str:
5047
return f"{self.rule.name}: {self.violation_msg}"
51-
52-
53-
class RuleSet(Mapping[str, type[Rule]]):
54-
def __init__(self, rules: Iterable[type[Rule]] = ()) -> None:
55-
self._underlying = {rule.name: rule for rule in rules}
56-
57-
def check_model(self, model: Model) -> t.List[RuleViolation]:
58-
violations = []
59-
60-
for rule in self._underlying.values():
61-
violation = rule().check_model(model)
62-
63-
if violation:
64-
violations.append(violation)
65-
66-
return violations
67-
68-
def __iter__(self) -> Iterator[str]:
69-
return iter(self._underlying)
70-
71-
def __len__(self) -> int:
72-
return len(self._underlying)
73-
74-
def __getitem__(self, rule: str | type[Rule]) -> type[Rule]:
75-
key = rule if isinstance(rule, str) else rule.name
76-
return self._underlying[key]
77-
78-
def __op(
79-
self,
80-
op: Callable[[Set[type[Rule]], Set[type[Rule]]], Set[type[Rule]]],
81-
other: RuleSet,
82-
/,
83-
) -> RuleSet:
84-
rules = set()
85-
for rule in op(set(self.values()), set(other.values())):
86-
rules.add(other[rule] if rule in other else self[rule])
87-
88-
return RuleSet(rules)
89-
90-
def union(self, *others: RuleSet) -> RuleSet:
91-
return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others))
92-
93-
def intersection(self, *others: RuleSet) -> RuleSet:
94-
return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others))
95-
96-
def difference(self, *others: RuleSet) -> RuleSet:
97-
return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others))

sqlmesh/core/linter/rules/builtin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from sqlglot.helper import subclasses
88

9-
from sqlmesh.core.linter.rule import Rule, RuleViolation, RuleSet
9+
from sqlmesh.core.linter.rule import Rule, RuleViolation
10+
from sqlmesh.core.linter.definition import RuleSet
1011
from sqlmesh.core.model import Model, SqlModel
1112

1213

sqlmesh/core/loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from sqlmesh.core.audit import Audit, ModelAudit, StandaloneAudit, load_multiple_audits
2121
from sqlmesh.core.dialect import parse
2222
from sqlmesh.core.environment import EnvironmentStatements
23-
from sqlmesh.core.linter.rule import RuleSet, Rule
23+
from sqlmesh.core.linter.rule import Rule
24+
from sqlmesh.core.linter.definition import RuleSet
2425
from sqlmesh.core.macros import MacroRegistry, macro
2526
from sqlmesh.core.metric import Metric, MetricMeta, expand_metrics, load_metric_ddl
2627
from sqlmesh.core.model import (

0 commit comments

Comments
 (0)