Skip to content

Commit 00a43e6

Browse files
committed
feat: add linting to the lsp
[ci skip]
1 parent a02d8b3 commit 00a43e6

4 files changed

Lines changed: 59 additions & 10 deletions

File tree

examples/sushi/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
PlanConfig,
1414
SparkConnectionConfig,
1515
)
16+
from sqlmesh.core.config.linter import LinterConfig
1617
from sqlmesh.core.notification_target import (
1718
BasicSMTPNotificationTarget,
1819
SlackApiNotificationTarget,
@@ -41,6 +42,15 @@
4142
},
4243
default_gateway="duckdb",
4344
model_defaults=model_defaults,
45+
linter=LinterConfig(
46+
enabled=False,
47+
rules=[
48+
"ambiguousorinvalidcolumn",
49+
"invalidselectstarexpansion",
50+
"noselectstar",
51+
"nomissingaudits",
52+
],
53+
),
4454
)
4555

4656
bigquery_config = Config(

sqlmesh/core/context.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
from sqlmesh.core.engine_adapter import EngineAdapter
7777
from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements
7878
from sqlmesh.core.loader import Loader
79-
from sqlmesh.core.linter.definition import Linter
79+
from sqlmesh.core.linter.definition import AnnotatedRuleViolation, Linter
8080
from sqlmesh.core.linter.rules import BUILTIN_RULES
8181
from sqlmesh.core.macros import ExecutableOrMacro, macro
8282
from sqlmesh.core.metric import Metric, rewrite
@@ -2461,22 +2461,29 @@ def _get_models_for_interval_end(
24612461
def lint_models(
24622462
self,
24632463
models: t.Optional[t.Iterable[t.Union[str, Model]]] = None,
2464-
) -> None:
2464+
raise_on_error: bool = True,
2465+
) -> t.List[AnnotatedRuleViolation]:
24652466
found_error = False
24662467

24672468
model_list = (
24682469
list(self.get_model(model) for model in models) if models else self.models.values()
24692470
)
2471+
all_violations = []
24702472
for model in model_list:
24712473
# Linter may be `None` if the context is not loaded yet
24722474
if linter := self._linters.get(model.project):
2473-
found_error = linter.lint_model(model, console=self.console) or found_error
2475+
found_error, violations = (
2476+
linter.lint_model(model, console=self.console) or found_error
2477+
)
2478+
all_violations.extend(violations)
24742479

2475-
if found_error:
2480+
if raise_on_error and found_error:
24762481
raise LinterError(
24772482
"Linter detected errors in the code. Please fix them before proceeding."
24782483
)
24792484

2485+
return all_violations
2486+
24802487
def load_model_tests(
24812488
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
24822489
) -> t.List[ModelTestMetadata]:

sqlmesh/core/linter/definition.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from functools import reduce
1010
from sqlmesh.core.model import Model
1111
from sqlmesh.core.linter.rule import Rule, RuleViolation
12+
from sqlmesh.core.console import LinterConsole, get_console
1213

1314

1415
def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
@@ -50,9 +51,11 @@ def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
5051

5152
return Linter(config.enabled, all_rules, rules, warn_rules)
5253

53-
def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bool:
54+
def lint_model(
55+
self, model: Model, console: LinterConsole = get_console()
56+
) -> t.Tuple[bool, t.List[AnnotatedRuleViolation]]:
5457
if not self.enabled:
55-
return False
58+
return False, []
5659

5760
ignored_rules = select_rules(self.all_rules, model.ignored_rules)
5861

@@ -62,14 +65,31 @@ def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bo
6265
error_violations = rules.check_model(model)
6366
warn_violations = warn_rules.check_model(model)
6467

68+
all_violations: t.List[AnnotatedRuleViolation] = [
69+
AnnotatedRuleViolation(
70+
rule=violation.rule,
71+
violation_msg=violation.violation_msg,
72+
model=model,
73+
violation_type="error",
74+
)
75+
for violation in error_violations
76+
] + [
77+
AnnotatedRuleViolation(
78+
rule=violation.rule,
79+
violation_msg=violation.violation_msg,
80+
model=model,
81+
violation_type="warning",
82+
)
83+
for violation in warn_violations
84+
]
85+
6586
if warn_violations:
6687
console.show_linter_violations(warn_violations, model)
67-
6888
if error_violations:
6989
console.show_linter_violations(error_violations, model, is_error=True)
70-
return True
90+
return True, all_violations
7191

72-
return False
92+
return False, all_violations
7393

7494

7595
class RuleSet(Mapping[str, type[Rule]]):
@@ -117,3 +137,16 @@ def intersection(self, *others: RuleSet) -> RuleSet:
117137

118138
def difference(self, *others: RuleSet) -> RuleSet:
119139
return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others))
140+
141+
142+
class AnnotatedRuleViolation(RuleViolation):
143+
def __init__(
144+
self,
145+
rule: Rule,
146+
violation_msg: str,
147+
model: Model,
148+
violation_type: t.Literal["error", "warning"],
149+
) -> None:
150+
super().__init__(rule, violation_msg)
151+
self.model = model
152+
self.violation_type = violation_type

sqlmesh/core/linter/rule.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import abc
44

5-
65
from sqlmesh.core.model import Model
76

87
from typing import Type

0 commit comments

Comments
 (0)