|
1 | 1 | from __future__ import annotations |
2 | | - |
3 | 2 | import typing as t |
4 | | - |
5 | 3 | from sqlmesh.core.config.linter import LinterConfig |
6 | | - |
7 | 4 | from sqlmesh.core.model import Model |
8 | | - |
9 | 5 | from sqlmesh.utils.errors import raise_config_error |
10 | 6 | from sqlmesh.core.console import LinterConsole, get_console |
11 | | -from sqlmesh.core.linter.rule import RuleSet |
| 7 | +import operator as op |
| 8 | +from collections.abc import Iterator, Iterable, Set, Mapping, Callable |
| 9 | +from functools import reduce |
| 10 | +from sqlmesh.core.model import Model |
| 11 | +from sqlmesh.core.linter.rule import Rule, RuleViolation |
12 | 12 |
|
13 | 13 |
|
14 | 14 | def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet: |
@@ -70,3 +70,50 @@ def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bo |
70 | 70 | return True |
71 | 71 |
|
72 | 72 | return False |
| 73 | + |
| 74 | + |
| 75 | +class RuleSet(Mapping[str, type[Rule]]): |
| 76 | + def __init__(self, rules: Iterable[type[Rule]] = ()) -> None: |
| 77 | + self._underlying = {rule.name: rule for rule in rules} |
| 78 | + |
| 79 | + def check_model(self, model: Model) -> t.List[RuleViolation]: |
| 80 | + violations = [] |
| 81 | + |
| 82 | + for rule in self._underlying.values(): |
| 83 | + violation = rule().check_model(model) |
| 84 | + |
| 85 | + if violation: |
| 86 | + violations.append(violation) |
| 87 | + |
| 88 | + return violations |
| 89 | + |
| 90 | + def __iter__(self) -> Iterator[str]: |
| 91 | + return iter(self._underlying) |
| 92 | + |
| 93 | + def __len__(self) -> int: |
| 94 | + return len(self._underlying) |
| 95 | + |
| 96 | + def __getitem__(self, rule: str | type[Rule]) -> type[Rule]: |
| 97 | + key = rule if isinstance(rule, str) else rule.name |
| 98 | + return self._underlying[key] |
| 99 | + |
| 100 | + def __op( |
| 101 | + self, |
| 102 | + op: Callable[[Set[type[Rule]], Set[type[Rule]]], Set[type[Rule]]], |
| 103 | + other: RuleSet, |
| 104 | + /, |
| 105 | + ) -> RuleSet: |
| 106 | + rules = set() |
| 107 | + for rule in op(set(self.values()), set(other.values())): |
| 108 | + rules.add(other[rule] if rule in other else self[rule]) |
| 109 | + |
| 110 | + return RuleSet(rules) |
| 111 | + |
| 112 | + def union(self, *others: RuleSet) -> RuleSet: |
| 113 | + return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others)) |
| 114 | + |
| 115 | + def intersection(self, *others: RuleSet) -> RuleSet: |
| 116 | + return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others)) |
| 117 | + |
| 118 | + def difference(self, *others: RuleSet) -> RuleSet: |
| 119 | + return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others)) |
0 commit comments