Skip to content

Commit bb73daf

Browse files
authored
feat: publish linter errors through lsp (#4152)
1 parent 83ba551 commit bb73daf

5 files changed

Lines changed: 203 additions & 29 deletions

File tree

examples/sushi/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
PlanConfig,
1414
SparkConnectionConfig,
1515
)
16+
from sqlmesh.core.config.linter import LinterConfig
1617
from sqlmesh.core.notification_target import (
1718
BasicSMTPNotificationTarget,
1819
SlackApiNotificationTarget,
@@ -41,6 +42,15 @@
4142
},
4243
default_gateway="duckdb",
4344
model_defaults=model_defaults,
45+
linter=LinterConfig(
46+
enabled=False,
47+
rules=[
48+
"ambiguousorinvalidcolumn",
49+
"invalidselectstarexpansion",
50+
"noselectstar",
51+
"nomissingaudits",
52+
],
53+
),
4454
)
4555

4656
bigquery_config = Config(

sqlmesh/core/context.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
from sqlmesh.core.engine_adapter import EngineAdapter
7777
from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements
7878
from sqlmesh.core.loader import Loader
79-
from sqlmesh.core.linter.definition import Linter
79+
from sqlmesh.core.linter.definition import AnnotatedRuleViolation, Linter
8080
from sqlmesh.core.linter.rules import BUILTIN_RULES
8181
from sqlmesh.core.macros import ExecutableOrMacro, macro
8282
from sqlmesh.core.metric import Metric, rewrite
@@ -2466,22 +2466,31 @@ def _get_models_for_interval_end(
24662466
def lint_models(
24672467
self,
24682468
models: t.Optional[t.Iterable[t.Union[str, Model]]] = None,
2469-
) -> None:
2469+
raise_on_error: bool = True,
2470+
) -> t.List[AnnotatedRuleViolation]:
24702471
found_error = False
24712472

24722473
model_list = (
24732474
list(self.get_model(model) for model in models) if models else self.models.values()
24742475
)
2476+
all_violations = []
24752477
for model in model_list:
24762478
# Linter may be `None` if the context is not loaded yet
24772479
if linter := self._linters.get(model.project):
2478-
found_error = linter.lint_model(model, console=self.console) or found_error
2480+
lint_violation, violations = (
2481+
linter.lint_model(model, console=self.console) or found_error
2482+
)
2483+
if lint_violation:
2484+
found_error = True
2485+
all_violations.extend(violations)
24792486

2480-
if found_error:
2487+
if raise_on_error and found_error:
24812488
raise LinterError(
24822489
"Linter detected errors in the code. Please fix them before proceeding."
24832490
)
24842491

2492+
return all_violations
2493+
24852494
def load_model_tests(
24862495
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
24872496
) -> t.List[ModelTestMetadata]:

sqlmesh/core/linter/definition.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from functools import reduce
1010
from sqlmesh.core.model import Model
1111
from sqlmesh.core.linter.rule import Rule, RuleViolation
12+
from sqlmesh.core.console import LinterConsole, get_console
1213

1314

1415
def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
@@ -50,9 +51,11 @@ def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
5051

5152
return Linter(config.enabled, all_rules, rules, warn_rules)
5253

53-
def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bool:
54+
def lint_model(
55+
self, model: Model, console: LinterConsole = get_console()
56+
) -> t.Tuple[bool, t.List[AnnotatedRuleViolation]]:
5457
if not self.enabled:
55-
return False
58+
return False, []
5659

5760
ignored_rules = select_rules(self.all_rules, model.ignored_rules)
5861

@@ -62,14 +65,31 @@ def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bo
6265
error_violations = rules.check_model(model)
6366
warn_violations = warn_rules.check_model(model)
6467

68+
all_violations: t.List[AnnotatedRuleViolation] = [
69+
AnnotatedRuleViolation(
70+
rule=violation.rule,
71+
violation_msg=violation.violation_msg,
72+
model=model,
73+
violation_type="error",
74+
)
75+
for violation in error_violations
76+
] + [
77+
AnnotatedRuleViolation(
78+
rule=violation.rule,
79+
violation_msg=violation.violation_msg,
80+
model=model,
81+
violation_type="warning",
82+
)
83+
for violation in warn_violations
84+
]
85+
6586
if warn_violations:
6687
console.show_linter_violations(warn_violations, model)
67-
6888
if error_violations:
6989
console.show_linter_violations(error_violations, model, is_error=True)
70-
return True
90+
return True, all_violations
7191

72-
return False
92+
return False, all_violations
7393

7494

7595
class RuleSet(Mapping[str, type[Rule]]):
@@ -117,3 +137,16 @@ def intersection(self, *others: RuleSet) -> RuleSet:
117137

118138
def difference(self, *others: RuleSet) -> RuleSet:
119139
return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others))
140+
141+
142+
class AnnotatedRuleViolation(RuleViolation):
143+
def __init__(
144+
self,
145+
rule: Rule,
146+
violation_msg: str,
147+
model: Model,
148+
violation_type: t.Literal["error", "warning"],
149+
) -> None:
150+
super().__init__(rule, violation_msg)
151+
self.model = model
152+
self.violation_type = violation_type

sqlmesh/core/linter/rule.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import abc
44

5-
65
from sqlmesh.core.model import Model
76

87
from typing import Type

sqlmesh/lsp/main.py

Lines changed: 142 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,33 @@
11
#!/usr/bin/env python
22
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals."""
33

4+
from collections import defaultdict
45
import logging
56
import typing as t
67
from pathlib import Path
78

89
from lsprotocol import types
910
from pygls.server import LanguageServer
10-
from pygls.workspace import TextDocument
1111

1212
from sqlmesh._version import __version__
1313
from sqlmesh.core.context import Context
14+
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
15+
16+
17+
class LSPContext:
18+
"""
19+
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
20+
"""
21+
22+
def __init__(self, context: Context) -> None:
23+
self.context = context
24+
map: t.Dict[str, t.List[str]] = defaultdict(list[str])
25+
for model in context.models.values():
26+
if model._path is None:
27+
path = Path(model._path).resolve()
28+
map[f"file://{path.as_posix()}"].append(model.name)
29+
30+
self.map = map
1431

1532

1633
class SQLMeshLanguageServer:
@@ -27,29 +44,87 @@ def __init__(
2744
"""
2845
self.server = LanguageServer(server_name, version)
2946
self.context_class = context_class
30-
self.context: t.Optional[Context] = None
47+
self.lsp_context: t.Optional[LSPContext] = None
48+
self.lint_cache: t.Dict[str, t.List[AnnotatedRuleViolation]] = {}
3149

3250
# Register LSP features (e.g., formatting, hover, etc.)
3351
self._register_features()
3452

3553
def _register_features(self) -> None:
3654
"""Register LSP features on the internal LanguageServer instance."""
3755

56+
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
57+
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
58+
context = self._context_get_or_load(params.text_document.uri)
59+
if self.lint_cache.get(params.text_document.uri) is not None:
60+
ls.publish_diagnostics(
61+
params.text_document.uri,
62+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
63+
self.lint_cache[params.text_document.uri]
64+
),
65+
)
66+
return
67+
models = context.map[params.text_document.uri]
68+
if models is None:
69+
return
70+
self.lint_cache[params.text_document.uri] = context.context.lint_models(
71+
models,
72+
raise_on_error=False,
73+
)
74+
ls.publish_diagnostics(
75+
params.text_document.uri,
76+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
77+
self.lint_cache[params.text_document.uri]
78+
),
79+
)
80+
81+
@self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
82+
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
83+
context = self._context_get_or_load(params.text_document.uri)
84+
models = context.map[params.text_document.uri]
85+
if models is None:
86+
return
87+
self.lint_cache[params.text_document.uri] = context.context.lint_models(
88+
models,
89+
raise_on_error=False,
90+
)
91+
ls.publish_diagnostics(
92+
params.text_document.uri,
93+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
94+
self.lint_cache[params.text_document.uri]
95+
),
96+
)
97+
98+
@self.server.feature(types.TEXT_DOCUMENT_DID_SAVE)
99+
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
100+
context = self._context_get_or_load(params.text_document.uri)
101+
models = context.map[params.text_document.uri]
102+
if models is None:
103+
return
104+
self.lint_cache[params.text_document.uri] = context.context.lint_models(
105+
models,
106+
raise_on_error=False,
107+
)
108+
ls.publish_diagnostics(
109+
params.text_document.uri,
110+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
111+
self.lint_cache[params.text_document.uri]
112+
),
113+
)
114+
38115
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
39116
def formatting(
40117
ls: LanguageServer, params: types.DocumentFormattingParams
41118
) -> t.List[types.TextEdit]:
42119
"""Format the document using SQLMesh `format_model_expressions`."""
43120
try:
44-
document = self.ensure_context_for_document(
45-
ls.workspace.get_document(params.text_document.uri)
46-
)
47-
48-
if self.context is None:
121+
self._ensure_context_for_document(params.text_document.uri)
122+
document = ls.workspace.get_document(params.text_document.uri)
123+
if self.lsp_context is None:
49124
raise RuntimeError(f"No context found for document: {document.path}")
50125

51126
# Perform formatting using the loaded context
52-
self.context.format(paths=(Path(document.path),))
127+
self.lsp_context.context.format(paths=(Path(document.path),))
53128
with open(document.path, "r+", encoding="utf-8") as file:
54129
new_text = file.read()
55130

@@ -70,20 +145,31 @@ def formatting(
70145
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
71146
return []
72147

73-
def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
148+
def _context_get_or_load(self, document_uri: str) -> LSPContext:
149+
if self.lsp_context is None:
150+
self._ensure_context_for_document(document_uri)
151+
if self.lsp_context is None:
152+
raise RuntimeError("No context found")
153+
return self.lsp_context
154+
155+
def _ensure_context_for_document(
156+
self,
157+
document_uri: str,
158+
) -> None:
74159
"""
75160
Ensure that a context exists for the given document if applicable by searching
76161
for a config.py or config.yml file in the parent directories.
77162
"""
78-
# If the context is already loaded, check if this document belongs to it.
79-
if self.context is not None:
80-
self.context.load() # Reload or refresh context
81-
return document
163+
if self.lsp_context is not None:
164+
context = self.lsp_context
165+
context.context.load() # Reload or refresh context
166+
self.lsp_context = LSPContext(context.context)
167+
return
82168

83169
# No context yet: try to find config and load it
84-
path = Path(document.path).resolve()
170+
path = Path(self._uri_to_path(document_uri)).resolve()
85171
if path.suffix not in (".sql", ".py"):
86-
return document
172+
return
87173

88174
loaded = False
89175
# Ascend directories to look for config
@@ -93,18 +179,55 @@ def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
93179
if config_path.exists():
94180
try:
95181
# Use user-provided instantiator to build the context
96-
self.context = self.context_class(paths=[path])
97-
self.server.show_message(f"Context loaded for: {path}")
182+
created_context = self.context_class(paths=[path])
183+
self.lsp_context = LSPContext(created_context)
98184
loaded = True
99185
# Re-check context for document now that it's loaded
100-
return self.ensure_context_for_document(document)
186+
return self._ensure_context_for_document(document_uri)
101187
except Exception as e:
102188
self.server.show_message(
103189
f"Error loading context: {e}", types.MessageType.Error
104190
)
105191
path = path.parent
106192

107-
return document
193+
return
194+
195+
@staticmethod
196+
def _diagnostic_to_lsp_diagnostic(
197+
diagnostic: AnnotatedRuleViolation,
198+
) -> t.Optional[types.Diagnostic]:
199+
if diagnostic.model._path is None:
200+
return None
201+
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
202+
lines = file.readlines()
203+
return types.Diagnostic(
204+
range=types.Range(
205+
start=types.Position(line=0, character=0),
206+
end=types.Position(line=len(lines), character=len(lines[-1])),
207+
),
208+
message=diagnostic.violation_msg,
209+
severity=types.DiagnosticSeverity.Error
210+
if diagnostic.violation_type == "error"
211+
else types.DiagnosticSeverity.Warning,
212+
)
213+
214+
@staticmethod
215+
def _diagnostics_to_lsp_diagnostics(
216+
diagnostics: t.List[AnnotatedRuleViolation],
217+
) -> t.List[types.Diagnostic]:
218+
lsp_diagnostics: t.List[types.Diagnostic] = []
219+
for diagnostic in diagnostics:
220+
lsp_diagnostic = SQLMeshLanguageServer._diagnostic_to_lsp_diagnostic(diagnostic)
221+
if lsp_diagnostic is not None:
222+
lsp_diagnostics.append(lsp_diagnostic)
223+
return lsp_diagnostics
224+
225+
@staticmethod
226+
def _uri_to_path(uri: str) -> str:
227+
"""Convert a URI to a path."""
228+
if uri.startswith("file://"):
229+
return Path(uri[7:]).resolve().as_posix()
230+
return Path(uri).resolve().as_posix()
108231

109232
def start(self) -> None:
110233
"""Start the server with I/O transport."""

0 commit comments

Comments
 (0)