Skip to content

Commit f8c517a

Browse files
committed
feat: add linting to the lsp
1 parent 1bc4661 commit f8c517a

4 files changed

Lines changed: 132 additions & 9 deletions

File tree

sqlmesh/core/context.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
)
7676
from sqlmesh.core.engine_adapter import EngineAdapter
7777
from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements
78+
from sqlmesh.core.linter.rule import RuleViolationWithModelAndType
7879
from sqlmesh.core.loader import Loader
7980
from sqlmesh.core.linter.definition import Linter
8081
from sqlmesh.core.linter.rules import BUILTIN_RULES
@@ -2460,7 +2461,8 @@ def _get_models_for_interval_end(
24602461
def lint_models(
24612462
self,
24622463
models: t.Optional[t.Iterable[t.Union[str, Model]]] = None,
2463-
) -> None:
2464+
raise_on_error: bool = True,
2465+
) -> t.List[RuleViolationWithModelAndType]:
24642466
found_error = False
24652467

24662468
model_list = (
@@ -2469,13 +2471,17 @@ def lint_models(
24692471
for model in model_list:
24702472
# Linter may be `None` if the context is not loaded yet
24712473
if linter := self._linters.get(model.project):
2472-
found_error = linter.lint_model(model, console=self.console) or found_error
2474+
found_error, all_violations = (
2475+
linter.lint_model(model, console=self.console) or found_error
2476+
)
24732477

2474-
if found_error:
2478+
if raise_on_error and found_error:
24752479
raise LinterError(
24762480
"Linter detected errors in the code. Please fix them before proceeding."
24772481
)
24782482

2483+
return all_violations
2484+
24792485
def load_model_tests(
24802486
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
24812487
) -> t.List[ModelTestMetadata]:

sqlmesh/core/linter/definition.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from sqlmesh.utils.errors import raise_config_error
1010
from sqlmesh.core.console import LinterConsole, get_console, Console
11-
from sqlmesh.core.linter.rule import RuleSet
11+
from sqlmesh.core.linter.rule import RuleSet, RuleViolationWithModelAndType
1212

1313

1414
def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
@@ -50,9 +50,11 @@ def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
5050

5151
return Linter(config.enabled, all_rules, rules, warn_rules)
5252

53-
def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bool:
53+
def lint_model(
54+
self, model: Model, console: LinterConsole = get_console()
55+
) -> t.Tuple[bool, t.List[RuleViolationWithModelAndType]]:
5456
if not self.enabled:
55-
return False
57+
return False, []
5658

5759
ignored_rules = select_rules(self.all_rules, model.ignored_rules)
5860

@@ -62,11 +64,28 @@ def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bo
6264
error_violations = rules.check_model(model)
6365
warn_violations = warn_rules.check_model(model)
6466

67+
all_violations: t.List[RuleViolationWithModelAndType] = [
68+
RuleViolationWithModelAndType(
69+
rule=violation.rule,
70+
violation_msg=violation.violation_msg,
71+
model=model,
72+
violation_type="error",
73+
)
74+
for violation in error_violations
75+
] + [
76+
RuleViolationWithModelAndType(
77+
rule=violation.rule,
78+
violation_msg=violation.violation_msg,
79+
model=model,
80+
violation_type="warning",
81+
)
82+
for violation in warn_violations
83+
]
84+
6585
if warn_violations:
6686
console.show_linter_violations(warn_violations, model)
67-
6887
if error_violations:
6988
console.show_linter_violations(error_violations, model, is_error=True)
70-
return True
89+
return True, all_violations
7190

72-
return False
91+
return False, all_violations

sqlmesh/core/linter/rule.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,19 @@ def __repr__(self) -> str:
5050
return f"{self.rule.name}: {self.violation_msg}"
5151

5252

53+
class RuleViolationWithModelAndType(RuleViolation):
54+
def __init__(
55+
self,
56+
rule: Rule,
57+
violation_msg: str,
58+
model: Model,
59+
violation_type: t.Literal["error", "warning"],
60+
) -> None:
61+
super().__init__(rule, violation_msg)
62+
self.model = model
63+
self.violation_type = violation_type
64+
65+
5366
class RuleSet(Mapping[str, type[Rule]]):
5467
def __init__(self, rules: Iterable[type[Rule]] = ()) -> None:
5568
self._underlying = {rule.name: rule for rule in rules}

sqlmesh/lsp/main.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sqlmesh._version import __version__
1515
from sqlmesh.core.audit.definition import ModelAudit
1616
from sqlmesh.core.context import Context
17+
from sqlmesh.core.linter.rule import RuleViolationWithModelAndType
1718
from sqlmesh.core.model import Model
1819

1920

@@ -32,13 +33,78 @@ def __init__(
3233
self.server = LanguageServer(server_name, version)
3334
self.context_class = context_class
3435
self.context: t.Optional[Context] = None
36+
self.lint_cache: t.Dict[str, t.List[RuleViolationWithModelAndType]] = {}
3537

3638
# Register LSP features (e.g., formatting, hover, etc.)
3739
self._register_features()
3840

3941
def _register_features(self) -> None:
4042
"""Register LSP features on the internal LanguageServer instance."""
4143

44+
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
45+
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
46+
if self.context is None:
47+
self.ensure_context_for_document(
48+
ls.workspace.get_document(params.text_document.uri)
49+
)
50+
if self.context is None:
51+
raise RuntimeError("No context found")
52+
self.lint_cache[params.text_document.uri] = self.context.lint_models(
53+
[params.text_document.uri]
54+
)
55+
ls.publish_diagnostics(
56+
params.text_document.uri,
57+
[
58+
self._diagnostic_to_lsp_diagnostic(diagnostic)
59+
for diagnostic in self.lint_cache[params.text_document.uri]
60+
if diagnostic is not None
61+
],
62+
)
63+
64+
@self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
65+
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
66+
if self.context is None:
67+
self.ensure_context_for_document(
68+
ls.workspace.get_document(params.text_document.uri)
69+
)
70+
if self.context is None:
71+
raise RuntimeError("No context found")
72+
self.lint_cache[params.text_document.uri] = self.context.lint_models(
73+
[params.text_document.uri]
74+
)
75+
ls.publish_diagnostics(
76+
params.text_document.uri,
77+
[
78+
self._diagnostic_to_lsp_diagnostic(diagnostic)
79+
for diagnostic in self.lint_cache[params.text_document.uri]
80+
if diagnostic is not None
81+
],
82+
)
83+
84+
@self.server.feature(types.TEXT_DOCUMENT_DID_CLOSE)
85+
def did_close(ls: LanguageServer, params: types.DidCloseTextDocumentParams) -> None:
86+
self.lint_cache.pop(params.text_document.uri, None)
87+
88+
@self.server.feature(types.TEXT_DOCUMENT_DID_SAVE)
89+
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
90+
if self.context is None:
91+
self.ensure_context_for_document(
92+
ls.workspace.get_document(params.text_document.uri)
93+
)
94+
if self.context is None:
95+
raise RuntimeError("No context found")
96+
self.lint_cache[params.text_document.uri] = self.context.lint_models(
97+
[params.text_document.uri]
98+
)
99+
ls.publish_diagnostics(
100+
params.text_document.uri,
101+
[
102+
self._diagnostic_to_lsp_diagnostic(diagnostic)
103+
for diagnostic in self.lint_cache[params.text_document.uri]
104+
if diagnostic is not None
105+
],
106+
)
107+
42108
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
43109
def formatting(
44110
ls: LanguageServer, params: types.DocumentFormattingParams
@@ -110,6 +176,25 @@ def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
110176

111177
return document
112178

179+
@staticmethod
180+
def _diagnostic_to_lsp_diagnostic(
181+
diagnostic: RuleViolationWithModelAndType,
182+
) -> t.Optional[types.Diagnostic]:
183+
if diagnostic.model._path is None:
184+
return None
185+
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
186+
lines = file.readlines()
187+
return types.Diagnostic(
188+
range=types.Range(
189+
start=types.Position(line=0, character=0),
190+
end=types.Position(line=len(lines), character=len(lines[-1])),
191+
),
192+
message=diagnostic.violation_msg,
193+
severity=types.DiagnosticSeverity.Error
194+
if diagnostic.violation_type == "error"
195+
else types.DiagnosticSeverity.Warning,
196+
)
197+
113198
def start(self) -> None:
114199
"""Start the server with I/O transport."""
115200
logging.basicConfig(level=logging.DEBUG)

0 commit comments

Comments
 (0)