Skip to content

Commit a25c53f

Browse files
committed
chore(vscode): introduce pull first diagnostics
1 parent f4fa53f commit a25c53f

1 file changed

Lines changed: 174 additions & 33 deletions

File tree

sqlmesh/lsp/main.py

Lines changed: 174 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ def __init__(
5757
self.server = LanguageServer(server_name, version)
5858
self.context_class = context_class
5959
self.lsp_context: t.Optional[LSPContext] = None
60-
self.lint_cache: t.Dict[URI, t.List[AnnotatedRuleViolation]] = {}
60+
# Cache stores tuples of (diagnostics, result_id)
61+
self.lint_cache: t.Dict[URI, t.Tuple[t.List[AnnotatedRuleViolation], str]] = {}
62+
self.client_supports_pull_diagnostics = False
63+
self._diagnostic_version_counter = 0
6164

6265
# Register LSP features (e.g., formatting, hover, etc.)
6366
self._register_features()
@@ -69,6 +72,18 @@ def _register_features(self) -> None:
6972
def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
7073
"""Initialize the server when the client connects."""
7174
try:
75+
# Check if client supports pull diagnostics
76+
if params.capabilities and params.capabilities.text_document:
77+
diagnostics = getattr(params.capabilities.text_document, 'diagnostic', None)
78+
if diagnostics:
79+
self.client_supports_pull_diagnostics = True
80+
ls.log_trace("Client supports pull diagnostics")
81+
else:
82+
self.client_supports_pull_diagnostics = False
83+
ls.log_trace("Client does not support pull diagnostics")
84+
else:
85+
self.client_supports_pull_diagnostics = False
86+
7287
if params.workspace_folders:
7388
# Try to find a SQLMesh config file in any workspace folder (only at the root level)
7489
for folder in params.workspace_folders:
@@ -153,61 +168,76 @@ def api(ls: LanguageServer, request: ApiRequest) -> t.Dict[str, t.Any]:
153168
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
154169
uri = URI(params.text_document.uri)
155170
context = self._context_get_or_load(uri)
156-
if self.lint_cache.get(uri) is not None:
171+
172+
# Always update the cache
173+
models = context.map[uri.to_path()]
174+
if models is None or not isinstance(models, ModelTarget):
175+
return
176+
177+
if self.lint_cache.get(uri) is None:
178+
diagnostics = context.context.lint_models(
179+
models.names,
180+
raise_on_error=False,
181+
)
182+
self._diagnostic_version_counter += 1
183+
result_id = str(self._diagnostic_version_counter)
184+
self.lint_cache[uri] = (diagnostics, result_id)
185+
186+
# Only publish diagnostics if client doesn't support pull diagnostics
187+
if not self.client_supports_pull_diagnostics:
188+
diagnostics, _ = self.lint_cache[uri]
157189
ls.publish_diagnostics(
158190
params.text_document.uri,
159-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
191+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
160192
)
161-
return
162-
models = context.map[uri.to_path()]
163-
if models is None:
164-
return
165-
if not isinstance(models, ModelTarget):
166-
return
167-
self.lint_cache[uri] = context.context.lint_models(
168-
models.names,
169-
raise_on_error=False,
170-
)
171-
ls.publish_diagnostics(
172-
params.text_document.uri,
173-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
174-
)
175193

176194
@self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
177195
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
178196
uri = URI(params.text_document.uri)
179197
context = self._context_get_or_load(uri)
180198
models = context.map[uri.to_path()]
181-
if models is None:
182-
return
183-
if not isinstance(models, ModelTarget):
199+
if models is None or not isinstance(models, ModelTarget):
184200
return
185-
self.lint_cache[uri] = context.context.lint_models(
201+
202+
# Always update the cache
203+
diagnostics = context.context.lint_models(
186204
models.names,
187205
raise_on_error=False,
188206
)
189-
ls.publish_diagnostics(
190-
params.text_document.uri,
191-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
192-
)
207+
self._diagnostic_version_counter += 1
208+
result_id = str(self._diagnostic_version_counter)
209+
self.lint_cache[uri] = (diagnostics, result_id)
210+
211+
# Only publish diagnostics if client doesn't support pull diagnostics
212+
if not self.client_supports_pull_diagnostics:
213+
ls.publish_diagnostics(
214+
params.text_document.uri,
215+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
216+
)
193217

194218
@self.server.feature(types.TEXT_DOCUMENT_DID_SAVE)
195219
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
196220
uri = URI(params.text_document.uri)
197221
context = self._context_get_or_load(uri)
198222
models = context.map[uri.to_path()]
199-
if models is None:
223+
if models is None or not isinstance(models, ModelTarget):
200224
return
201-
if not isinstance(models, ModelTarget):
202-
return
203-
self.lint_cache[uri] = context.context.lint_models(
225+
226+
# Always update the cache
227+
diagnostics = context.context.lint_models(
204228
models.names,
205229
raise_on_error=False,
206230
)
207-
ls.publish_diagnostics(
208-
params.text_document.uri,
209-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
210-
)
231+
self._diagnostic_version_counter += 1
232+
result_id = str(self._diagnostic_version_counter)
233+
self.lint_cache[uri] = (diagnostics, result_id)
234+
235+
# Only publish diagnostics if client doesn't support pull diagnostics
236+
if not self.client_supports_pull_diagnostics:
237+
ls.publish_diagnostics(
238+
params.text_document.uri,
239+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
240+
)
211241

212242
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
213243
def formatting(
@@ -326,6 +356,117 @@ def goto_definition(
326356
except Exception as e:
327357
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
328358
return []
359+
360+
@self.server.feature(types.TEXT_DOCUMENT_DIAGNOSTIC)
361+
def diagnostic(
362+
ls: LanguageServer, params: types.DocumentDiagnosticParams
363+
) -> types.DocumentDiagnosticReport:
364+
"""Handle diagnostic pull requests from the client."""
365+
try:
366+
uri = URI(params.text_document.uri)
367+
diagnostics, result_id = self._get_diagnostics_for_uri(uri)
368+
369+
# Check if client provided a previous result ID
370+
if hasattr(params, 'previous_result_id') and params.previous_result_id == result_id:
371+
# Return unchanged report if diagnostics haven't changed
372+
return types.RelatedUnchangedDocumentDiagnosticReport(
373+
kind=types.DocumentDiagnosticReportKind.Unchanged,
374+
result_id=result_id,
375+
)
376+
377+
return types.RelatedFullDocumentDiagnosticReport(
378+
kind=types.DocumentDiagnosticReportKind.Full,
379+
items=diagnostics,
380+
result_id=result_id,
381+
)
382+
except Exception as e:
383+
ls.show_message(f"Error getting diagnostics: {e}", types.MessageType.Error)
384+
return types.RelatedFullDocumentDiagnosticReport(
385+
kind=types.DocumentDiagnosticReportKind.Full,
386+
items=[],
387+
)
388+
389+
@self.server.feature(types.WORKSPACE_DIAGNOSTIC)
390+
def workspace_diagnostic(
391+
ls: LanguageServer, params: types.WorkspaceDiagnosticParams
392+
) -> types.WorkspaceDiagnosticReport:
393+
"""Handle workspace-wide diagnostic pull requests from the client."""
394+
try:
395+
if self.lsp_context is None:
396+
current_path = Path.cwd()
397+
self._ensure_context_in_folder(current_path)
398+
399+
if self.lsp_context is None:
400+
return types.WorkspaceDiagnosticReport(items=[])
401+
402+
items = []
403+
404+
# Get all SQL and Python model files from the context
405+
for path, target in self.lsp_context.map.items():
406+
if isinstance(target, ModelTarget):
407+
uri = URI.from_path(path)
408+
diagnostics, result_id = self._get_diagnostics_for_uri(uri)
409+
410+
# Check if we have a previous result ID for this file
411+
previous_result_id = None
412+
if hasattr(params, 'previous_result_ids') and params.previous_result_ids:
413+
for prev in params.previous_result_ids:
414+
if prev.uri == uri.value:
415+
previous_result_id = prev.value
416+
break
417+
418+
if previous_result_id and previous_result_id == result_id:
419+
# File hasn't changed
420+
items.append(
421+
types.WorkspaceUnchangedDocumentDiagnosticReport(
422+
kind=types.DocumentDiagnosticReportKind.Unchanged,
423+
result_id=result_id,
424+
uri=uri.value,
425+
)
426+
)
427+
else:
428+
# File has changed or is new
429+
items.append(
430+
types.WorkspaceFullDocumentDiagnosticReport(
431+
kind=types.DocumentDiagnosticReportKind.Full,
432+
result_id=result_id,
433+
uri=uri.value,
434+
items=diagnostics,
435+
)
436+
)
437+
438+
return types.WorkspaceDiagnosticReport(items=items)
439+
440+
except Exception as e:
441+
ls.show_message(f"Error getting workspace diagnostics: {e}", types.MessageType.Error)
442+
return types.WorkspaceDiagnosticReport(items=[])
443+
444+
def _get_diagnostics_for_uri(self, uri: URI) -> t.Tuple[t.List[types.Diagnostic], str]:
445+
"""Get diagnostics for a specific URI, returning (diagnostics, result_id)."""
446+
# Check if we have cached diagnostics
447+
if uri in self.lint_cache:
448+
diagnostics, result_id = self.lint_cache[uri]
449+
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics), result_id
450+
451+
# Try to get diagnostics by loading context and linting
452+
try:
453+
context = self._context_get_or_load(uri)
454+
models = context.map[uri.to_path()]
455+
if models is None or not isinstance(models, ModelTarget):
456+
return [], "0"
457+
458+
# Lint the models and cache the results
459+
diagnostics = context.context.lint_models(
460+
models.names,
461+
raise_on_error=False,
462+
)
463+
self._diagnostic_version_counter += 1
464+
result_id = str(self._diagnostic_version_counter)
465+
self.lint_cache[uri] = (diagnostics, result_id)
466+
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics), result_id
467+
except Exception:
468+
# If we can't get diagnostics, return empty list with no result ID
469+
return [], "0"
329470

330471
def _context_get_or_load(self, document_uri: URI) -> LSPContext:
331472
if self.lsp_context is None:

0 commit comments

Comments
 (0)