Skip to content

Commit ce02caf

Browse files
committed
feat: make linter return violations
[ci skip]
1 parent 1865bd9 commit ce02caf

6 files changed

Lines changed: 26 additions & 13 deletions

File tree

examples/multi/repo_1/linter/user.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ class NoMissingDescription(Rule):
1212
"""All models should be documented."""
1313

1414
def check_model(self, model: Model) -> t.Optional[RuleViolation]:
15-
return self.violation() if not model.description else None
15+
return self.violation(model) if not model.description else None

examples/sushi/linter/user.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ class NoMissingOwner(Rule):
1010
"""All models should have an owner specified."""
1111

1212
def check_model(self, model: Model) -> t.Optional[RuleViolation]:
13-
return self.violation() if not model.owner else None
13+
return self.violation(model) if not model.owner else None

sqlmesh/core/context.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
)
7676
from sqlmesh.core.engine_adapter import EngineAdapter
7777
from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements
78+
from sqlmesh.core.linter.rule import RuleViolation
7879
from sqlmesh.core.loader import Loader
7980
from sqlmesh.core.linter.definition import Linter
8081
from sqlmesh.core.linter.rules import BUILTIN_RULES
@@ -2460,7 +2461,12 @@ def _get_models_for_interval_end(
24602461
def lint_models(
24612462
self,
24622463
models: t.Optional[t.Iterable[t.Union[str, Model]]] = None,
2463-
) -> None:
2464+
raise_error_on_error: bool = True,
2465+
) -> t.List[RuleViolation]:
2466+
"""
2467+
Lint models and raise an error if any errors are found, optionally the error can be ignored and the violations can be returned.
2468+
"""
2469+
violations = []
24642470
found_error = False
24652471

24662472
model_list = (
@@ -2471,7 +2477,7 @@ def lint_models(
24712477
if linter := self._linters.get(model.project):
24722478
found_error = linter.lint_model(model, console=self.console) or found_error
24732479

2474-
if found_error:
2480+
if raise_error_on_error and found_error:
24752481
raise LinterError(
24762482
"Linter detected errors in the code. Please fix them before proceeding."
24772483
)

sqlmesh/core/linter/definition.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from sqlmesh.utils.errors import raise_config_error
1010
from sqlmesh.core.console import LinterConsole, get_console, Console
11-
from sqlmesh.core.linter.rule import RuleSet
11+
from sqlmesh.core.linter.rule import RuleSet, RuleViolation
1212

1313

1414
def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
@@ -50,7 +50,12 @@ def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
5050

5151
return Linter(config.enabled, all_rules, rules, warn_rules)
5252

53-
def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bool:
53+
def lint_model(self, model: Model, console: LinterConsole = get_console()) -> t.Tuple[
54+
bool, t.List[RuleViolation]
55+
]:
56+
"""
57+
Lint a model and return a boolean indicating whether the model has any error violations. It also returns a list of violations.
58+
"""
5459
if not self.enabled:
5560
return False
5661

sqlmesh/core/linter/rule.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,20 @@ def summary(self) -> str:
3333
"""A summary of what this rule checks for."""
3434
return self.__doc__ or ""
3535

36-
def violation(self, violation_msg: t.Optional[str] = None) -> RuleViolation:
36+
def violation(self, model: Model, violation_type: t.Literal["error", "warning"], violation_msg: t.Optional[str] = None, ) -> RuleViolation:
3737
"""Create a RuleViolation instance for this rule"""
38-
return RuleViolation(rule=self, violation_msg=violation_msg or self.summary)
38+
return RuleViolation(rule=self, violation_msg=violation_msg or self.summary, model=model, violation_type=violation_type)
3939

4040
def __repr__(self) -> str:
4141
return self.name
4242

4343

4444
class RuleViolation:
45-
def __init__(self, rule: Rule, violation_msg: str) -> None:
45+
def __init__(self, rule: Rule, violation_msg: str, model: Model, violation_type: t.Literal["error", "warning"]) -> None:
4646
self.rule = rule
47+
self.violation_type: t.Literal["error", "warning"] = violation_type
4748
self.violation_msg = violation_msg
49+
self.model = model
4850

4951
def __repr__(self) -> str:
5052
return f"{self.rule.name}: {self.violation_msg}"

sqlmesh/core/linter/rules/builtin.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
1717
if not isinstance(model, SqlModel):
1818
return None
1919

20-
return self.violation() if model.query.is_star else None
20+
return self.violation(model) if model.query.is_star else None
2121

2222

2323
class InvalidSelectStarExpansion(Rule):
@@ -32,7 +32,7 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
3232
f"'{model.fqn}' can be rendered at parse time."
3333
)
3434

35-
return self.violation(violation_msg)
35+
return self.violation(model, violation_msg)
3636

3737

3838
class AmbiguousOrInvalidColumn(Rule):
@@ -45,14 +45,14 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
4545
f"{sqlglot_err} for model '{model.fqn}', the column may not exist or is ambiguous."
4646
)
4747

48-
return self.violation(violation_msg)
48+
return self.violation(model, violation_msg)
4949

5050

5151
class NoMissingAudits(Rule):
5252
"""Model `audits` must be configured to test data quality."""
5353

5454
def check_model(self, model: Model) -> t.Optional[RuleViolation]:
55-
return self.violation() if not model.audits and not model.kind.is_symbolic else None
55+
return self.violation(model) if not model.audits and not model.kind.is_symbolic else None
5656

5757

5858
BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, (Rule,)))

0 commit comments

Comments
 (0)