Skip to content

Commit 9a5041a

Browse files
committed
temp linting
1 parent 834a9a7 commit 9a5041a

2 files changed

Lines changed: 143 additions & 18 deletions

File tree

sqlmesh/core/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2477,9 +2477,11 @@ def lint_models(
24772477
for model in model_list:
24782478
# Linter may be `None` if the context is not loaded yet
24792479
if linter := self._linters.get(model.project):
2480-
found_error, violations = (
2480+
lint_violation, violations = (
24812481
linter.lint_model(model, console=self.console) or found_error
24822482
)
2483+
if lint_violation:
2484+
found_error = True
24832485
all_violations.extend(violations)
24842486

24852487
if raise_on_error and found_error:

sqlmesh/lsp/main.py

Lines changed: 140 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,33 @@
11
#!/usr/bin/env python
22
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals."""
33

4+
from collections import defaultdict
45
import logging
56
import typing as t
67
from pathlib import Path
78

89
from lsprotocol import types
910
from pygls.server import LanguageServer
10-
from pygls.workspace import TextDocument
1111

1212
from sqlmesh._version import __version__
1313
from sqlmesh.core.context import Context
14+
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
15+
16+
17+
class LSPContext:
18+
"""
19+
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
20+
"""
21+
22+
def __init__(self, context: Context) -> None:
23+
self.context = context
24+
map: t.Dict[str, t.List[str]] = defaultdict(list[str])
25+
for model in context.models.values():
26+
if model._path is None:
27+
path = Path(model._path).resolve()
28+
map[f"file://{path.as_posix()}"].append(model.name)
29+
30+
self.map = map
1431

1532

1633
class SQLMeshLanguageServer:
@@ -27,29 +44,87 @@ def __init__(
2744
"""
2845
self.server = LanguageServer(server_name, version)
2946
self.context_class = context_class
30-
self.context: t.Optional[Context] = None
47+
self.context: t.Optional[LSPContext] = None
48+
self.lint_cache: t.Dict[str, t.List[AnnotatedRuleViolation]] = {}
3149

3250
# Register LSP features (e.g., formatting, hover, etc.)
3351
self._register_features()
3452

3553
def _register_features(self) -> None:
3654
"""Register LSP features on the internal LanguageServer instance."""
3755

56+
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
57+
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
58+
context = self._context_get_or_load(params.text_document.uri)
59+
if self.lint_cache.get(params.text_document.uri) is not None:
60+
ls.publish_diagnostics(
61+
params.text_document.uri,
62+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
63+
self.lint_cache[params.text_document.uri]
64+
),
65+
)
66+
return
67+
models = context.map[params.text_document.uri]
68+
if models is None:
69+
return
70+
self.lint_cache[params.text_document.uri] = context.context.lint_models(
71+
models,
72+
raise_on_error=False,
73+
)
74+
ls.publish_diagnostics(
75+
params.text_document.uri,
76+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
77+
self.lint_cache[params.text_document.uri]
78+
),
79+
)
80+
81+
@self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
82+
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
83+
context = self._context_get_or_load(params.text_document.uri)
84+
models = context.map[params.text_document.uri]
85+
if models is None:
86+
return
87+
self.lint_cache[params.text_document.uri] = context.context.lint_models(
88+
models,
89+
raise_on_error=False,
90+
)
91+
ls.publish_diagnostics(
92+
params.text_document.uri,
93+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
94+
self.lint_cache[params.text_document.uri]
95+
),
96+
)
97+
98+
@self.server.feature(types.TEXT_DOCUMENT_DID_SAVE)
99+
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
100+
context = self._context_get_or_load(params.text_document.uri)
101+
models = context.map[params.text_document.uri]
102+
if models is None:
103+
return
104+
self.lint_cache[params.text_document.uri] = context.context.lint_models(
105+
models,
106+
raise_on_error=False,
107+
)
108+
ls.publish_diagnostics(
109+
params.text_document.uri,
110+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
111+
self.lint_cache[params.text_document.uri]
112+
),
113+
)
114+
38115
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
39116
def formatting(
40117
ls: LanguageServer, params: types.DocumentFormattingParams
41118
) -> t.List[types.TextEdit]:
42119
"""Format the document using SQLMesh `format_model_expressions`."""
43120
try:
44-
document = self.ensure_context_for_document(
45-
ls.workspace.get_document(params.text_document.uri)
46-
)
47-
121+
self._ensure_context_for_document(params.text_document.uri)
122+
document = ls.workspace.get_document(params.text_document.uri)
48123
if self.context is None:
49124
raise RuntimeError(f"No context found for document: {document.path}")
50125

51126
# Perform formatting using the loaded context
52-
self.context.format(paths=(Path(document.path),))
127+
self.context.context.format(paths=(Path(document.path),))
53128
with open(document.path, "r+", encoding="utf-8") as file:
54129
new_text = file.read()
55130

@@ -70,20 +145,31 @@ def formatting(
70145
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
71146
return []
72147

73-
def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
148+
def _context_get_or_load(self, document_uri: str) -> LSPContext:
149+
if self.context is None:
150+
self._ensure_context_for_document(document_uri)
151+
if self.context is None:
152+
raise RuntimeError("No context found")
153+
return self.context
154+
155+
def _ensure_context_for_document(
156+
self,
157+
document_uri: str,
158+
) -> None:
74159
"""
75160
Ensure that a context exists for the given document if applicable by searching
76161
for a config.py or config.yml file in the parent directories.
77162
"""
78-
# If the context is already loaded, check if this document belongs to it.
79163
if self.context is not None:
80-
self.context.load() # Reload or refresh context
81-
return document
164+
context = self.context
165+
context.context.load() # Reload or refresh context
166+
self.context = LSPContext(context.context)
167+
return
82168

83169
# No context yet: try to find config and load it
84-
path = Path(document.path).resolve()
170+
path = Path(self._uri_to_path(document_uri)).resolve()
85171
if path.suffix not in (".sql", ".py"):
86-
return document
172+
return
87173

88174
loaded = False
89175
# Ascend directories to look for config
@@ -93,18 +179,55 @@ def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
93179
if config_path.exists():
94180
try:
95181
# Use user-provided instantiator to build the context
96-
self.context = self.context_class(paths=[path])
97-
self.server.show_message(f"Context loaded for: {path}")
182+
created_context = self.context_class(paths=[path])
183+
self.context = LSPContext(created_context)
98184
loaded = True
99185
# Re-check context for document now that it's loaded
100-
return self.ensure_context_for_document(document)
186+
return self._ensure_context_for_document(document_uri)
101187
except Exception as e:
102188
self.server.show_message(
103189
f"Error loading context: {e}", types.MessageType.Error
104190
)
105191
path = path.parent
106192

107-
return document
193+
return
194+
195+
@staticmethod
196+
def _diagnostic_to_lsp_diagnostic(
197+
diagnostic: AnnotatedRuleViolation,
198+
) -> t.Optional[types.Diagnostic]:
199+
if diagnostic.model._path is None:
200+
return None
201+
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
202+
lines = file.readlines()
203+
return types.Diagnostic(
204+
range=types.Range(
205+
start=types.Position(line=0, character=0),
206+
end=types.Position(line=len(lines), character=len(lines[-1])),
207+
),
208+
message=diagnostic.violation_msg,
209+
severity=types.DiagnosticSeverity.Error
210+
if diagnostic.violation_type == "error"
211+
else types.DiagnosticSeverity.Warning,
212+
)
213+
214+
@staticmethod
215+
def _diagnostics_to_lsp_diagnostics(
216+
diagnostics: t.List[AnnotatedRuleViolation],
217+
) -> t.List[types.Diagnostic]:
218+
lsp_diagnostics: t.List[types.Diagnostic] = []
219+
for diagnostic in diagnostics:
220+
lsp_diagnostic = SQLMeshLanguageServer._diagnostic_to_lsp_diagnostic(diagnostic)
221+
if lsp_diagnostic is not None:
222+
lsp_diagnostics.append(lsp_diagnostic)
223+
return lsp_diagnostics
224+
225+
@staticmethod
226+
def _uri_to_path(uri: str) -> str:
227+
"""Convert a URI to a path."""
228+
if uri.startswith("file://"):
229+
return Path(uri[7:]).resolve().as_posix()
230+
return Path(uri).resolve().as_posix()
108231

109232
def start(self) -> None:
110233
"""Start the server with I/O transport."""

0 commit comments

Comments
 (0)