Skip to content

Commit 67c7dea

Browse files
committed
chore(vscode): introduce pull first diagnostics
1 parent 9e07751 commit 67c7dea

1 file changed

Lines changed: 181 additions & 33 deletions

File tree

sqlmesh/lsp/main.py

Lines changed: 181 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ def __init__(
5757
self.server = LanguageServer(server_name, version)
5858
self.context_class = context_class
5959
self.lsp_context: t.Optional[LSPContext] = None
60-
self.lint_cache: t.Dict[URI, t.List[AnnotatedRuleViolation]] = {}
60+
# Cache stores tuples of (diagnostics, result_id)
61+
self.lint_cache: t.Dict[URI, t.Tuple[t.List[AnnotatedRuleViolation], str]] = {}
62+
self.client_supports_pull_diagnostics = False
63+
self._diagnostic_version_counter = 0
6164

6265
# Register LSP features (e.g., formatting, hover, etc.)
6366
self._register_features()
@@ -69,6 +72,18 @@ def _register_features(self) -> None:
6972
def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
7073
"""Initialize the server when the client connects."""
7174
try:
75+
# Check if client supports pull diagnostics
76+
if params.capabilities and params.capabilities.text_document:
77+
diagnostics = getattr(params.capabilities.text_document, "diagnostic", None)
78+
if diagnostics:
79+
self.client_supports_pull_diagnostics = True
80+
ls.log_trace("Client supports pull diagnostics")
81+
else:
82+
self.client_supports_pull_diagnostics = False
83+
ls.log_trace("Client does not support pull diagnostics")
84+
else:
85+
self.client_supports_pull_diagnostics = False
86+
7287
if params.workspace_folders:
7388
# Try to find a SQLMesh config file in any workspace folder (only at the root level)
7489
for folder in params.workspace_folders:
@@ -153,61 +168,76 @@ def api(ls: LanguageServer, request: ApiRequest) -> t.Dict[str, t.Any]:
153168
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
154169
uri = URI(params.text_document.uri)
155170
context = self._context_get_or_load(uri)
156-
if self.lint_cache.get(uri) is not None:
171+
172+
# Always update the cache
173+
models = context.map[uri.to_path()]
174+
if models is None or not isinstance(models, ModelTarget):
175+
return
176+
177+
if self.lint_cache.get(uri) is None:
178+
diagnostics = context.context.lint_models(
179+
models.names,
180+
raise_on_error=False,
181+
)
182+
self._diagnostic_version_counter += 1
183+
result_id = str(self._diagnostic_version_counter)
184+
self.lint_cache[uri] = (diagnostics, result_id)
185+
186+
# Only publish diagnostics if client doesn't support pull diagnostics
187+
if not self.client_supports_pull_diagnostics:
188+
diagnostics, _ = self.lint_cache[uri]
157189
ls.publish_diagnostics(
158190
params.text_document.uri,
159-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
191+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
160192
)
161-
return
162-
models = context.map[uri.to_path()]
163-
if models is None:
164-
return
165-
if not isinstance(models, ModelTarget):
166-
return
167-
self.lint_cache[uri] = context.context.lint_models(
168-
models.names,
169-
raise_on_error=False,
170-
)
171-
ls.publish_diagnostics(
172-
params.text_document.uri,
173-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
174-
)
175193

176194
@self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
177195
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
178196
uri = URI(params.text_document.uri)
179197
context = self._context_get_or_load(uri)
180198
models = context.map[uri.to_path()]
181-
if models is None:
199+
if models is None or not isinstance(models, ModelTarget):
182200
return
183-
if not isinstance(models, ModelTarget):
184-
return
185-
self.lint_cache[uri] = context.context.lint_models(
201+
202+
# Always update the cache
203+
diagnostics = context.context.lint_models(
186204
models.names,
187205
raise_on_error=False,
188206
)
189-
ls.publish_diagnostics(
190-
params.text_document.uri,
191-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
192-
)
207+
self._diagnostic_version_counter += 1
208+
result_id = str(self._diagnostic_version_counter)
209+
self.lint_cache[uri] = (diagnostics, result_id)
210+
211+
# Only publish diagnostics if client doesn't support pull diagnostics
212+
if not self.client_supports_pull_diagnostics:
213+
ls.publish_diagnostics(
214+
params.text_document.uri,
215+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
216+
)
193217

194218
@self.server.feature(types.TEXT_DOCUMENT_DID_SAVE)
195219
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
196220
uri = URI(params.text_document.uri)
197221
context = self._context_get_or_load(uri)
198222
models = context.map[uri.to_path()]
199-
if models is None:
200-
return
201-
if not isinstance(models, ModelTarget):
223+
if models is None or not isinstance(models, ModelTarget):
202224
return
203-
self.lint_cache[uri] = context.context.lint_models(
225+
226+
# Always update the cache
227+
diagnostics = context.context.lint_models(
204228
models.names,
205229
raise_on_error=False,
206230
)
207-
ls.publish_diagnostics(
208-
params.text_document.uri,
209-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
210-
)
231+
self._diagnostic_version_counter += 1
232+
result_id = str(self._diagnostic_version_counter)
233+
self.lint_cache[uri] = (diagnostics, result_id)
234+
235+
# Only publish diagnostics if client doesn't support pull diagnostics
236+
if not self.client_supports_pull_diagnostics:
237+
ls.publish_diagnostics(
238+
params.text_document.uri,
239+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
240+
)
211241

212242
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
213243
def formatting(
@@ -327,6 +357,124 @@ def goto_definition(
327357
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
328358
return []
329359

360+
@self.server.feature(types.TEXT_DOCUMENT_DIAGNOSTIC)
361+
def diagnostic(
362+
ls: LanguageServer, params: types.DocumentDiagnosticParams
363+
) -> types.DocumentDiagnosticReport:
364+
"""Handle diagnostic pull requests from the client."""
365+
try:
366+
uri = URI(params.text_document.uri)
367+
diagnostics, result_id = self._get_diagnostics_for_uri(uri)
368+
369+
# Check if client provided a previous result ID
370+
if hasattr(params, "previous_result_id") and params.previous_result_id == result_id:
371+
# Return unchanged report if diagnostics haven't changed
372+
return types.RelatedUnchangedDocumentDiagnosticReport(
373+
kind=types.DocumentDiagnosticReportKind.Unchanged,
374+
result_id=result_id,
375+
)
376+
377+
return types.RelatedFullDocumentDiagnosticReport(
378+
kind=types.DocumentDiagnosticReportKind.Full,
379+
items=diagnostics,
380+
result_id=result_id,
381+
)
382+
except Exception as e:
383+
ls.show_message(f"Error getting diagnostics: {e}", types.MessageType.Error)
384+
return types.RelatedFullDocumentDiagnosticReport(
385+
kind=types.DocumentDiagnosticReportKind.Full,
386+
items=[],
387+
)
388+
389+
@self.server.feature(types.WORKSPACE_DIAGNOSTIC)
390+
def workspace_diagnostic(
391+
ls: LanguageServer, params: types.WorkspaceDiagnosticParams
392+
) -> types.WorkspaceDiagnosticReport:
393+
"""Handle workspace-wide diagnostic pull requests from the client."""
394+
try:
395+
if self.lsp_context is None:
396+
current_path = Path.cwd()
397+
self._ensure_context_in_folder(current_path)
398+
399+
if self.lsp_context is None:
400+
return types.WorkspaceDiagnosticReport(items=[])
401+
402+
items: t.List[
403+
t.Union[
404+
types.WorkspaceFullDocumentDiagnosticReport,
405+
types.WorkspaceUnchangedDocumentDiagnosticReport,
406+
]
407+
] = []
408+
409+
# Get all SQL and Python model files from the context
410+
for path, target in self.lsp_context.map.items():
411+
if isinstance(target, ModelTarget):
412+
uri = URI.from_path(path)
413+
diagnostics, result_id = self._get_diagnostics_for_uri(uri)
414+
415+
# Check if we have a previous result ID for this file
416+
previous_result_id = None
417+
if hasattr(params, "previous_result_ids") and params.previous_result_ids:
418+
for prev in params.previous_result_ids:
419+
if prev.uri == uri.value:
420+
previous_result_id = prev.value
421+
break
422+
423+
if previous_result_id and previous_result_id == result_id:
424+
# File hasn't changed
425+
items.append(
426+
types.WorkspaceUnchangedDocumentDiagnosticReport(
427+
kind=types.DocumentDiagnosticReportKind.Unchanged,
428+
result_id=result_id,
429+
uri=uri.value,
430+
)
431+
)
432+
else:
433+
# File has changed or is new
434+
items.append(
435+
types.WorkspaceFullDocumentDiagnosticReport(
436+
kind=types.DocumentDiagnosticReportKind.Full,
437+
result_id=result_id,
438+
uri=uri.value,
439+
items=diagnostics,
440+
)
441+
)
442+
443+
return types.WorkspaceDiagnosticReport(items=items)
444+
445+
except Exception as e:
446+
ls.show_message(
447+
f"Error getting workspace diagnostics: {e}", types.MessageType.Error
448+
)
449+
return types.WorkspaceDiagnosticReport(items=[])
450+
451+
def _get_diagnostics_for_uri(self, uri: URI) -> t.Tuple[t.List[types.Diagnostic], str]:
452+
"""Get diagnostics for a specific URI, returning (diagnostics, result_id)."""
453+
# Check if we have cached diagnostics
454+
if uri in self.lint_cache:
455+
diagnostics, result_id = self.lint_cache[uri]
456+
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics), result_id
457+
458+
# Try to get diagnostics by loading context and linting
459+
try:
460+
context = self._context_get_or_load(uri)
461+
models = context.map[uri.to_path()]
462+
if models is None or not isinstance(models, ModelTarget):
463+
return [], "0"
464+
465+
# Lint the models and cache the results
466+
diagnostics = context.context.lint_models(
467+
models.names,
468+
raise_on_error=False,
469+
)
470+
self._diagnostic_version_counter += 1
471+
result_id = str(self._diagnostic_version_counter)
472+
self.lint_cache[uri] = (diagnostics, result_id)
473+
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics), result_id
474+
except Exception:
475+
# If we can't get diagnostics, return empty list with no result ID
476+
return [], "0"
477+
330478
def _context_get_or_load(self, document_uri: URI) -> LSPContext:
331479
if self.lsp_context is None:
332480
self._ensure_context_for_document(document_uri)

0 commit comments

Comments
 (0)