Skip to content

Commit 94b36ce

Browse files
committed
temp linting [ci skip]
1 parent 00a43e6 commit 94b36ce

1 file changed

Lines changed: 143 additions & 15 deletions

File tree

sqlmesh/lsp/main.py

Lines changed: 143 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,33 @@
11
#!/usr/bin/env python
22
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals."""
33

4+
from collections import defaultdict
45
import logging
56
import typing as t
67
from pathlib import Path
78

89
from lsprotocol import types
910
from pygls.server import LanguageServer
10-
from pygls.workspace import TextDocument
1111

1212
from sqlmesh._version import __version__
1313
from sqlmesh.core.context import Context
14+
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
15+
16+
17+
class LSPContext:
18+
"""
19+
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
20+
"""
21+
22+
def __init__(self, context: Context) -> None:
23+
self.context = context
24+
map: t.Dict[str, t.List[str]] = defaultdict(list[str])
25+
for model in context.models.values():
26+
if model._path is None:
27+
path = Path(model._path).resolve()
28+
map[f"file://{path.as_posix()}"].append(model.name)
29+
30+
self.map = map
1431

1532

1633
class SQLMeshLanguageServer:
@@ -27,29 +44,88 @@ def __init__(
2744
"""
2845
self.server = LanguageServer(server_name, version)
2946
self.context_class = context_class
30-
self.context: t.Optional[Context] = None
47+
# A tuple of (context, reverse_map) where the reverse_map is uri to model name
48+
self.context: t.Optional[LSPContext] = None
49+
self.lint_cache: t.Dict[str, t.List[AnnotatedRuleViolation]] = {}
3150

3251
# Register LSP features (e.g., formatting, hover, etc.)
3352
self._register_features()
3453

3554
def _register_features(self) -> None:
3655
"""Register LSP features on the internal LanguageServer instance."""
3756

57+
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
58+
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
59+
context = self._context_get_or_load(params.text_document.uri)
60+
if self.lint_cache.get(params.text_document.uri) is not None:
61+
ls.publish_diagnostics(
62+
params.text_document.uri,
63+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
64+
self.lint_cache[params.text_document.uri]
65+
),
66+
)
67+
return
68+
models = context.map[params.text_document.uri]
69+
if models is None:
70+
return
71+
self.lint_cache[params.text_document.uri] = context.context.lint_models(
72+
models,
73+
raise_on_error=False,
74+
)
75+
ls.publish_diagnostics(
76+
params.text_document.uri,
77+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
78+
self.lint_cache[params.text_document.uri]
79+
),
80+
)
81+
82+
@self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
83+
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
84+
context = self._context_get_or_load(params.text_document.uri)
85+
models = context.map[params.text_document.uri]
86+
if models is None:
87+
return
88+
self.lint_cache[params.text_document.uri] = context.context.lint_models(
89+
models,
90+
raise_on_error=False,
91+
)
92+
ls.publish_diagnostics(
93+
params.text_document.uri,
94+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
95+
self.lint_cache[params.text_document.uri]
96+
),
97+
)
98+
99+
@self.server.feature(types.TEXT_DOCUMENT_DID_SAVE)
100+
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
101+
context = self._context_get_or_load(params.text_document.uri)
102+
models = context.map[params.text_document.uri]
103+
if models is None:
104+
return
105+
self.lint_cache[params.text_document.uri] = context.context.lint_models(
106+
models,
107+
raise_on_error=False,
108+
)
109+
ls.publish_diagnostics(
110+
params.text_document.uri,
111+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
112+
self.lint_cache[params.text_document.uri]
113+
),
114+
)
115+
38116
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
39117
def formatting(
40118
ls: LanguageServer, params: types.DocumentFormattingParams
41119
) -> t.List[types.TextEdit]:
42120
"""Format the document using SQLMesh `format_model_expressions`."""
43121
try:
44-
document = self.ensure_context_for_document(
45-
ls.workspace.get_document(params.text_document.uri)
46-
)
47-
122+
self._ensure_context_for_document(params.text_document.uri)
123+
document = ls.workspace.get_document(params.text_document.uri)
48124
if self.context is None:
49125
raise RuntimeError(f"No context found for document: {document.path}")
50126

51127
# Perform formatting using the loaded context
52-
self.context.format(paths=(Path(document.path),))
128+
self.context.context.format(paths=(Path(document.path),))
53129
with open(document.path, "r+", encoding="utf-8") as file:
54130
new_text = file.read()
55131

@@ -70,20 +146,32 @@ def formatting(
70146
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
71147
return []
72148

73-
def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
149+
def _context_get_or_load(self, document_uri: str) -> LSPContext:
150+
if self.context is None:
151+
self._ensure_context_for_document(document_uri)
152+
if self.context is None:
153+
raise RuntimeError("No context found")
154+
return self.context
155+
156+
def _ensure_context_for_document(
157+
self,
158+
document_uri: str,
159+
) -> None:
74160
"""
75161
Ensure that a context exists for the given document if applicable by searching
76162
for a config.py or config.yml file in the parent directories.
77163
"""
78164
# If the context is already loaded, check if this document belongs to it.
79165
if self.context is not None:
80-
self.context.load() # Reload or refresh context
81-
return document
166+
context = self.context
167+
context.context.load() # Reload or refresh context
168+
self.context = LSPContext(context.context)
169+
return
82170

83171
# No context yet: try to find config and load it
84-
path = Path(document.path).resolve()
172+
path = Path(self._uri_to_path(document_uri)).resolve()
85173
if path.suffix not in (".sql", ".py"):
86-
return document
174+
return
87175

88176
loaded = False
89177
# Ascend directories to look for config
@@ -93,18 +181,58 @@ def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
93181
if config_path.exists():
94182
try:
95183
# Use user-provided instantiator to build the context
96-
self.context = self.context_class(paths=[path])
184+
created_context = self.context_class(paths=[path])
185+
self.context = LSPContext(created_context)
97186
self.server.show_message(f"Context loaded for: {path}")
98187
loaded = True
99188
# Re-check context for document now that it's loaded
100-
return self.ensure_context_for_document(document)
189+
return self._ensure_context_for_document(document_uri)
101190
except Exception as e:
102191
self.server.show_message(
103192
f"Error loading context: {e}", types.MessageType.Error
104193
)
105194
path = path.parent
106195

107-
return document
196+
return
197+
198+
@staticmethod
199+
def _diagnostic_to_lsp_diagnostic(
200+
diagnostic: AnnotatedRuleViolation,
201+
) -> t.Optional[types.Diagnostic]:
202+
if diagnostic.model._path is None:
203+
return None
204+
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
205+
lines = file.readlines()
206+
return types.Diagnostic(
207+
range=types.Range(
208+
start=types.Position(line=0, character=0),
209+
end=types.Position(line=len(lines), character=len(lines[-1])),
210+
),
211+
message=diagnostic.violation_msg,
212+
severity=types.DiagnosticSeverity.Error
213+
if diagnostic.violation_type == "error"
214+
else types.DiagnosticSeverity.Warning,
215+
)
216+
217+
@staticmethod
218+
def _diagnostics_to_lsp_diagnostics(
219+
diagnostics: t.List[AnnotatedRuleViolation],
220+
) -> t.List[types.Diagnostic]:
221+
lsp_diagnostics: t.List[types.Diagnostic] = []
222+
for diagnostic in diagnostics:
223+
if diagnostic is None:
224+
continue
225+
lsp_diagnostic = SQLMeshLanguageServer._diagnostic_to_lsp_diagnostic(diagnostic)
226+
if lsp_diagnostic is not None:
227+
lsp_diagnostics.append(lsp_diagnostic)
228+
return lsp_diagnostics
229+
230+
@staticmethod
231+
def _uri_to_path(uri: str) -> str:
232+
"""Convert a URI to a path."""
233+
if uri.startswith("file://"):
234+
return Path(uri[7:]).resolve().as_posix()
235+
return Path(uri).resolve().as_posix()
108236

109237
def start(self) -> None:
110238
"""Start the server with I/O transport."""

0 commit comments

Comments
 (0)