Skip to content

Commit d1c5493

Browse files
committed
feat: add linting to the lsp
1 parent 1bc4661 commit d1c5493

5 files changed

Lines changed: 190 additions & 15 deletions

File tree

examples/sushi/config.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
PlanConfig,
1414
SparkConnectionConfig,
1515
)
16+
from sqlmesh.core.config.linter import LinterConfig
1617
from sqlmesh.core.notification_target import (
1718
BasicSMTPNotificationTarget,
1819
SlackApiNotificationTarget,
@@ -41,6 +42,15 @@
4142
},
4243
default_gateway="duckdb",
4344
model_defaults=model_defaults,
45+
linter=LinterConfig(
46+
enabled=True,
47+
rules=[
48+
"ambiguousorinvalidcolumn",
49+
"invalidselectstarexpansion",
50+
"noselectstar",
51+
"nomissingaudits",
52+
],
53+
),
4454
)
4555

4656
bigquery_config = Config(

sqlmesh/core/context.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
)
7676
from sqlmesh.core.engine_adapter import EngineAdapter
7777
from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements
78+
from sqlmesh.core.linter.rule import RuleViolationWithModelAndType
7879
from sqlmesh.core.loader import Loader
7980
from sqlmesh.core.linter.definition import Linter
8081
from sqlmesh.core.linter.rules import BUILTIN_RULES
@@ -2460,7 +2461,8 @@ def _get_models_for_interval_end(
24602461
def lint_models(
24612462
self,
24622463
models: t.Optional[t.Iterable[t.Union[str, Model]]] = None,
2463-
) -> None:
2464+
raise_on_error: bool = True,
2465+
) -> t.List[RuleViolationWithModelAndType]:
24642466
found_error = False
24652467

24662468
model_list = (
@@ -2469,13 +2471,17 @@ def lint_models(
24692471
for model in model_list:
24702472
# Linter may be `None` if the context is not loaded yet
24712473
if linter := self._linters.get(model.project):
2472-
found_error = linter.lint_model(model, console=self.console) or found_error
2474+
found_error, all_violations = (
2475+
linter.lint_model(model, console=self.console) or found_error
2476+
)
24732477

2474-
if found_error:
2478+
if raise_on_error and found_error:
24752479
raise LinterError(
24762480
"Linter detected errors in the code. Please fix them before proceeding."
24772481
)
24782482

2483+
return all_violations
2484+
24792485
def load_model_tests(
24802486
self, tests: t.Optional[t.List[str]] = None, patterns: list[str] | None = None
24812487
) -> t.List[ModelTestMetadata]:

sqlmesh/core/linter/definition.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from sqlmesh.utils.errors import raise_config_error
1010
from sqlmesh.core.console import LinterConsole, get_console, Console
11-
from sqlmesh.core.linter.rule import RuleSet
11+
from sqlmesh.core.linter.rule import RuleSet, RuleViolationWithModelAndType
1212

1313

1414
def select_rules(all_rules: RuleSet, rule_names: t.Set[str]) -> RuleSet:
@@ -50,9 +50,11 @@ def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter:
5050

5151
return Linter(config.enabled, all_rules, rules, warn_rules)
5252

53-
def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bool:
53+
def lint_model(
54+
self, model: Model, console: LinterConsole = get_console()
55+
) -> t.Tuple[bool, t.List[RuleViolationWithModelAndType]]:
5456
if not self.enabled:
55-
return False
57+
return False, []
5658

5759
ignored_rules = select_rules(self.all_rules, model.ignored_rules)
5860

@@ -62,11 +64,28 @@ def lint_model(self, model: Model, console: LinterConsole = get_console()) -> bo
6264
error_violations = rules.check_model(model)
6365
warn_violations = warn_rules.check_model(model)
6466

67+
all_violations: t.List[RuleViolationWithModelAndType] = [
68+
RuleViolationWithModelAndType(
69+
rule=violation.rule,
70+
violation_msg=violation.violation_msg,
71+
model=model,
72+
violation_type="error",
73+
)
74+
for violation in error_violations
75+
] + [
76+
RuleViolationWithModelAndType(
77+
rule=violation.rule,
78+
violation_msg=violation.violation_msg,
79+
model=model,
80+
violation_type="warning",
81+
)
82+
for violation in warn_violations
83+
]
84+
6585
if warn_violations:
6686
console.show_linter_violations(warn_violations, model)
67-
6887
if error_violations:
6988
console.show_linter_violations(error_violations, model, is_error=True)
70-
return True
89+
return True, all_violations
7190

72-
return False
91+
return False, all_violations

sqlmesh/core/linter/rule.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,19 @@ def __repr__(self) -> str:
5050
return f"{self.rule.name}: {self.violation_msg}"
5151

5252

53+
class RuleViolationWithModelAndType(RuleViolation):
54+
def __init__(
55+
self,
56+
rule: Rule,
57+
violation_msg: str,
58+
model: Model,
59+
violation_type: t.Literal["error", "warning"],
60+
) -> None:
61+
super().__init__(rule, violation_msg)
62+
self.model = model
63+
self.violation_type = violation_type
64+
65+
5366
class RuleSet(Mapping[str, type[Rule]]):
5467
def __init__(self, rules: Iterable[type[Rule]] = ()) -> None:
5568
self._underlying = {rule.name: rule for rule in rules}

sqlmesh/lsp/main.py

Lines changed: 133 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sqlmesh._version import __version__
1515
from sqlmesh.core.audit.definition import ModelAudit
1616
from sqlmesh.core.context import Context
17+
from sqlmesh.core.linter.rule import RuleViolationWithModelAndType
1718
from sqlmesh.core.model import Model
1819

1920

@@ -31,14 +32,85 @@ def __init__(
3132
"""
3233
self.server = LanguageServer(server_name, version)
3334
self.context_class = context_class
34-
self.context: t.Optional[Context] = None
35+
# A tuple of (context, reverse_map) where the reverse_map is uri to model name
36+
self.context_and_reverse_map: t.Optional[t.Tuple[Context, t.Dict[str, str]]] = None
37+
self.lint_cache: t.Dict[str, t.List[RuleViolationWithModelAndType]] = {}
3538

3639
# Register LSP features (e.g., formatting, hover, etc.)
3740
self._register_features()
3841

3942
def _register_features(self) -> None:
4043
"""Register LSP features on the internal LanguageServer instance."""
4144

45+
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
46+
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
47+
if self.context_and_reverse_map is None:
48+
self.ensure_context_for_document(
49+
ls.workspace.get_document(params.text_document.uri)
50+
)
51+
if self.context_and_reverse_map is None:
52+
raise RuntimeError("No context found")
53+
model = self.context_and_reverse_map[1][params.text_document.uri]
54+
if model is None:
55+
return
56+
self.lint_cache[params.text_document.uri] = self.context_and_reverse_map[0].lint_models(
57+
[model],
58+
raise_on_error=False,
59+
)
60+
ls.publish_diagnostics(
61+
params.text_document.uri,
62+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
63+
self.lint_cache[params.text_document.uri]
64+
),
65+
)
66+
67+
@self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
68+
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
69+
if self.context_and_reverse_map is None:
70+
self.ensure_context_for_document(
71+
ls.workspace.get_document(params.text_document.uri)
72+
)
73+
if self.context_and_reverse_map is None:
74+
raise RuntimeError("No context found")
75+
model = self.context_and_reverse_map[1][params.text_document.uri]
76+
if model is None:
77+
return
78+
self.lint_cache[params.text_document.uri] = self.context_and_reverse_map[0].lint_models(
79+
[model],
80+
raise_on_error=False,
81+
)
82+
ls.publish_diagnostics(
83+
params.text_document.uri,
84+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
85+
self.lint_cache[params.text_document.uri]
86+
),
87+
)
88+
89+
@self.server.feature(types.TEXT_DOCUMENT_DID_CLOSE)
90+
def did_close(ls: LanguageServer, params: types.DidCloseTextDocumentParams) -> None:
91+
self.lint_cache.pop(params.text_document.uri, None)
92+
93+
@self.server.feature(types.TEXT_DOCUMENT_DID_SAVE)
94+
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
95+
if self.context_and_reverse_map is None:
96+
self.ensure_context_for_document(
97+
ls.workspace.get_document(params.text_document.uri)
98+
)
99+
if self.context_and_reverse_map is None:
100+
raise RuntimeError("No context found")
101+
model = self.context_and_reverse_map[1][params.text_document.uri]
102+
if model is None:
103+
return
104+
self.lint_cache[params.text_document.uri] = self.context_and_reverse_map[0].lint_models(
105+
[model]
106+
)
107+
ls.publish_diagnostics(
108+
params.text_document.uri,
109+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
110+
self.lint_cache[params.text_document.uri]
111+
),
112+
)
113+
42114
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
43115
def formatting(
44116
ls: LanguageServer, params: types.DocumentFormattingParams
@@ -49,11 +121,11 @@ def formatting(
49121
ls.workspace.get_document(params.text_document.uri)
50122
)
51123

52-
if self.context is None:
124+
if self.context_and_reverse_map is None:
53125
raise RuntimeError(f"No context found for document: {document.path}")
54126

55127
# Perform formatting using the loaded context
56-
self.context.format(paths=(Path(document.path),))
128+
self.context_and_reverse_map[0].format(paths=(Path(document.path),))
57129
with open(document.path, "r+", encoding="utf-8") as file:
58130
new_text = file.read()
59131

@@ -80,8 +152,16 @@ def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
80152
for a config.py or config.yml file in the parent directories.
81153
"""
82154
# If the context is already loaded, check if this document belongs to it.
83-
if self.context is not None:
84-
self.context.load() # Reload or refresh context
155+
if self.context_and_reverse_map is not None:
156+
context, _ = self.context_and_reverse_map
157+
context.load() # Reload or refresh context
158+
self.context_and_reverse_map = (
159+
context,
160+
{
161+
f"file://{Path(model._path).resolve().as_posix()}": model.name
162+
for model in context._models.values()
163+
},
164+
)
85165
return document
86166

87167
# No context yet: try to find config and load it
@@ -97,7 +177,15 @@ def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
97177
if config_path.exists():
98178
try:
99179
# Use user-provided instantiator to build the context
100-
self.context = self.context_class(paths=[path])
180+
context = self.context_class(paths=[path])
181+
self.context_and_reverse_map = (
182+
context,
183+
{
184+
f"file://{Path(model._path).resolve().as_posix()}": model.name
185+
for model in context._models.values()
186+
if model._path is not None
187+
},
188+
)
101189
self.server.show_message(f"Context loaded for: {path}")
102190
loaded = True
103191
# Re-check context for document now that it's loaded
@@ -110,6 +198,45 @@ def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
110198

111199
return document
112200

201+
@staticmethod
202+
def _diagnostic_to_lsp_diagnostic(
203+
diagnostic: RuleViolationWithModelAndType,
204+
) -> t.Optional[types.Diagnostic]:
205+
if diagnostic.model._path is None:
206+
return None
207+
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
208+
lines = file.readlines()
209+
return types.Diagnostic(
210+
range=types.Range(
211+
start=types.Position(line=0, character=0),
212+
end=types.Position(line=len(lines), character=len(lines[-1])),
213+
),
214+
message=diagnostic.violation_msg,
215+
severity=types.DiagnosticSeverity.Error
216+
if diagnostic.violation_type == "error"
217+
else types.DiagnosticSeverity.Warning,
218+
)
219+
220+
@staticmethod
221+
def _diagnostics_to_lsp_diagnostics(
222+
diagnostics: t.List[RuleViolationWithModelAndType],
223+
) -> t.List[types.Diagnostic]:
224+
lsp_diagnostics: t.List[types.Diagnostic] = []
225+
for diagnostic in diagnostics:
226+
if diagnostic is None:
227+
continue
228+
lsp_diagnostic = SQLMeshLanguageServer._diagnostic_to_lsp_diagnostic(diagnostic)
229+
if lsp_diagnostic is not None:
230+
lsp_diagnostics.append(lsp_diagnostic)
231+
return lsp_diagnostics
232+
233+
@staticmethod
234+
def _uri_to_path(uri: str) -> str:
235+
"""Convert a URI to a path."""
236+
if uri.startswith("file://"):
237+
return Path(uri[7:]).resolve().as_posix()
238+
return Path(uri).resolve().as_posix()
239+
113240
def start(self) -> None:
114241
"""Start the server with I/O transport."""
115242
logging.basicConfig(level=logging.DEBUG)

0 commit comments

Comments
 (0)