Skip to content

Commit 1fb40ee

Browse files
authored
chore(vscode): introduce pull first diagnostics (#4565)
1 parent 2785fc9 commit 1fb40ee

2 files changed

Lines changed: 221 additions & 33 deletions

File tree

sqlmesh/lsp/main.py

Lines changed: 178 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,12 @@ def __init__(
5757
self.server = LanguageServer(server_name, version)
5858
self.context_class = context_class
5959
self.lsp_context: t.Optional[LSPContext] = None
60-
self.lint_cache: t.Dict[URI, t.List[AnnotatedRuleViolation]] = {}
6160

61+
# Cache stores tuples of (diagnostics, diagnostic_version)
62+
self.lint_cache: t.Dict[URI, t.Tuple[t.List[AnnotatedRuleViolation], int]] = {}
63+
self._diagnostic_version_counter: int = 0
64+
65+
self.client_supports_pull_diagnostics = False
6266
# Register LSP features (e.g., formatting, hover, etc.)
6367
self._register_features()
6468

@@ -69,6 +73,18 @@ def _register_features(self) -> None:
6973
def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
7074
"""Initialize the server when the client connects."""
7175
try:
76+
# Check if client supports pull diagnostics
77+
if params.capabilities and params.capabilities.text_document:
78+
diagnostics = getattr(params.capabilities.text_document, "diagnostic", None)
79+
if diagnostics:
80+
self.client_supports_pull_diagnostics = True
81+
ls.log_trace("Client supports pull diagnostics")
82+
else:
83+
self.client_supports_pull_diagnostics = False
84+
ls.log_trace("Client does not support pull diagnostics")
85+
else:
86+
self.client_supports_pull_diagnostics = False
87+
7288
if params.workspace_folders:
7389
# Try to find a SQLMesh config file in any workspace folder (only at the root level)
7490
for folder in params.workspace_folders:
@@ -153,61 +169,71 @@ def api(ls: LanguageServer, request: ApiRequest) -> t.Dict[str, t.Any]:
153169
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
154170
uri = URI(params.text_document.uri)
155171
context = self._context_get_or_load(uri)
156-
if self.lint_cache.get(uri) is not None:
172+
models = context.map[uri.to_path()]
173+
if models is None or not isinstance(models, ModelTarget):
174+
return
175+
176+
if self.lint_cache.get(uri) is None:
177+
diagnostics = context.context.lint_models(
178+
models.names,
179+
raise_on_error=False,
180+
)
181+
self._diagnostic_version_counter += 1
182+
self.lint_cache[uri] = (diagnostics, self._diagnostic_version_counter)
183+
184+
# Only publish diagnostics if client doesn't support pull diagnostics
185+
if not self.client_supports_pull_diagnostics:
186+
diagnostics, _ = self.lint_cache[uri]
157187
ls.publish_diagnostics(
158188
params.text_document.uri,
159-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
189+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
160190
)
161-
return
162-
models = context.map[uri.to_path()]
163-
if models is None:
164-
return
165-
if not isinstance(models, ModelTarget):
166-
return
167-
self.lint_cache[uri] = context.context.lint_models(
168-
models.names,
169-
raise_on_error=False,
170-
)
171-
ls.publish_diagnostics(
172-
params.text_document.uri,
173-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
174-
)
175191

176192
@self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
177193
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
178194
uri = URI(params.text_document.uri)
179195
context = self._context_get_or_load(uri)
180196
models = context.map[uri.to_path()]
181-
if models is None:
197+
if models is None or not isinstance(models, ModelTarget):
182198
return
183-
if not isinstance(models, ModelTarget):
184-
return
185-
self.lint_cache[uri] = context.context.lint_models(
199+
200+
# Always update the cache
201+
diagnostics = context.context.lint_models(
186202
models.names,
187203
raise_on_error=False,
188204
)
189-
ls.publish_diagnostics(
190-
params.text_document.uri,
191-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
192-
)
205+
self._diagnostic_version_counter += 1
206+
self.lint_cache[uri] = (diagnostics, self._diagnostic_version_counter)
207+
208+
# Only publish diagnostics if client doesn't support pull diagnostics
209+
if not self.client_supports_pull_diagnostics:
210+
ls.publish_diagnostics(
211+
params.text_document.uri,
212+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
213+
)
193214

194215
@self.server.feature(types.TEXT_DOCUMENT_DID_SAVE)
195216
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
196217
uri = URI(params.text_document.uri)
197218
context = self._context_get_or_load(uri)
198219
models = context.map[uri.to_path()]
199-
if models is None:
200-
return
201-
if not isinstance(models, ModelTarget):
220+
if models is None or not isinstance(models, ModelTarget):
202221
return
203-
self.lint_cache[uri] = context.context.lint_models(
222+
223+
# Always update the cache
224+
diagnostics = context.context.lint_models(
204225
models.names,
205226
raise_on_error=False,
206227
)
207-
ls.publish_diagnostics(
208-
params.text_document.uri,
209-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
210-
)
228+
self._diagnostic_version_counter += 1
229+
self.lint_cache[uri] = (diagnostics, self._diagnostic_version_counter)
230+
231+
# Only publish diagnostics if client doesn't support pull diagnostics
232+
if not self.client_supports_pull_diagnostics:
233+
ls.publish_diagnostics(
234+
params.text_document.uri,
235+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
236+
)
211237

212238
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
213239
def formatting(
@@ -327,6 +353,125 @@ def goto_definition(
327353
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
328354
return []
329355

356+
@self.server.feature(types.TEXT_DOCUMENT_DIAGNOSTIC)
357+
def diagnostic(
358+
ls: LanguageServer, params: types.DocumentDiagnosticParams
359+
) -> types.DocumentDiagnosticReport:
360+
"""Handle diagnostic pull requests from the client."""
361+
try:
362+
uri = URI(params.text_document.uri)
363+
diagnostics, result_id = self._get_diagnostics_for_uri(uri)
364+
365+
# Check if client provided a previous result ID
366+
if hasattr(params, "previous_result_id") and params.previous_result_id == result_id:
367+
# Return unchanged report if diagnostics haven't changed
368+
return types.RelatedUnchangedDocumentDiagnosticReport(
369+
kind=types.DocumentDiagnosticReportKind.Unchanged,
370+
result_id=str(result_id),
371+
)
372+
373+
return types.RelatedFullDocumentDiagnosticReport(
374+
kind=types.DocumentDiagnosticReportKind.Full,
375+
items=diagnostics,
376+
result_id=str(result_id),
377+
)
378+
except Exception as e:
379+
ls.show_message(f"Error getting diagnostics: {e}", types.MessageType.Error)
380+
return types.RelatedFullDocumentDiagnosticReport(
381+
kind=types.DocumentDiagnosticReportKind.Full,
382+
items=[],
383+
)
384+
385+
@self.server.feature(types.WORKSPACE_DIAGNOSTIC)
386+
def workspace_diagnostic(
387+
ls: LanguageServer, params: types.WorkspaceDiagnosticParams
388+
) -> types.WorkspaceDiagnosticReport:
389+
"""Handle workspace-wide diagnostic pull requests from the client."""
390+
try:
391+
if self.lsp_context is None:
392+
current_path = Path.cwd()
393+
self._ensure_context_in_folder(current_path)
394+
395+
if self.lsp_context is None:
396+
return types.WorkspaceDiagnosticReport(items=[])
397+
398+
items: t.List[
399+
t.Union[
400+
types.WorkspaceFullDocumentDiagnosticReport,
401+
types.WorkspaceUnchangedDocumentDiagnosticReport,
402+
]
403+
] = []
404+
405+
# Get all SQL and Python model files from the context
406+
for path, target in self.lsp_context.map.items():
407+
if isinstance(target, ModelTarget):
408+
uri = URI.from_path(path)
409+
diagnostics, result_id = self._get_diagnostics_for_uri(uri)
410+
411+
# Check if we have a previous result ID for this file
412+
previous_result_id = None
413+
if hasattr(params, "previous_result_ids") and params.previous_result_ids:
414+
for prev in params.previous_result_ids:
415+
if prev.uri == uri.value:
416+
previous_result_id = prev.value
417+
break
418+
419+
if previous_result_id and previous_result_id == result_id:
420+
# File hasn't changed
421+
items.append(
422+
types.WorkspaceUnchangedDocumentDiagnosticReport(
423+
kind=types.DocumentDiagnosticReportKind.Unchanged,
424+
result_id=str(result_id),
425+
uri=uri.value,
426+
)
427+
)
428+
else:
429+
# File has changed or is new
430+
items.append(
431+
types.WorkspaceFullDocumentDiagnosticReport(
432+
kind=types.DocumentDiagnosticReportKind.Full,
433+
result_id=str(result_id),
434+
uri=uri.value,
435+
items=diagnostics,
436+
)
437+
)
438+
439+
return types.WorkspaceDiagnosticReport(items=items)
440+
441+
except Exception as e:
442+
ls.show_message(
443+
f"Error getting workspace diagnostics: {e}", types.MessageType.Error
444+
)
445+
return types.WorkspaceDiagnosticReport(items=[])
446+
447+
def _get_diagnostics_for_uri(self, uri: URI) -> t.Tuple[t.List[types.Diagnostic], int]:
448+
"""Get diagnostics for a specific URI, returning (diagnostics, result_id)."""
449+
# Check if we have cached diagnostics
450+
if uri in self.lint_cache:
451+
diagnostics, result_id = self.lint_cache[uri]
452+
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics), result_id
453+
454+
# Try to get diagnostics by loading context and linting
455+
try:
456+
context = self._context_get_or_load(uri)
457+
models = context.map[uri.to_path()]
458+
if models is None or not isinstance(models, ModelTarget):
459+
return [], 0
460+
461+
# Lint the models and cache the results
462+
diagnostics = context.context.lint_models(
463+
models.names,
464+
raise_on_error=False,
465+
)
466+
self._diagnostic_version_counter += 1
467+
self.lint_cache[uri] = (diagnostics, self._diagnostic_version_counter)
468+
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(
469+
diagnostics
470+
), self._diagnostic_version_counter
471+
except Exception:
472+
# If we can't get diagnostics, return empty list with no result ID
473+
return [], 0
474+
330475
def _context_get_or_load(self, document_uri: URI) -> LSPContext:
331476
if self.lsp_context is None:
332477
self._ensure_context_for_document(document_uri)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import { test, expect } from '@playwright/test';
2+
import path from 'path';
3+
import fs from 'fs-extra';
4+
import os from 'os';
5+
import { startVSCode, SUSHI_SOURCE_PATH } from './utils';
6+
7+
test('Workspace diagnostics show up in the diagnostics panel', async () => {
8+
const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'vscode-test-sushi-'));
9+
await fs.copy(SUSHI_SOURCE_PATH, tempDir);
10+
11+
const configPath = path.join(tempDir, 'config.py');
12+
const configContent = await fs.readFile(configPath, 'utf8');
13+
const updatedContent = configContent.replace('enabled=False', 'enabled=True');
14+
await fs.writeFile(configPath, updatedContent);
15+
16+
try {
17+
const { window, close } = await startVSCode(tempDir);
18+
19+
// Wait for the models folder to be visible
20+
await window.waitForSelector('text=models');
21+
22+
// Click on the models folder, excluding external_models
23+
await window.getByRole('treeitem', { name: 'models', exact: true }).locator('a').click();
24+
25+
// Open the customer_revenue_lifetime model
26+
await window.getByRole('treeitem', { name: 'customers.sql', exact: true }).locator('a').click();
27+
28+
await
29+
30+
// Open problems panel
31+
await window.keyboard.press(process.platform === 'darwin' ? 'Meta+Shift+P' : 'Control+Shift+P');
32+
await window.keyboard.type('View: Focus Problems');
33+
await window.keyboard.press('Enter');
34+
35+
36+
await window.waitForSelector('text=problems');
37+
await window.waitForSelector("text=All models should have an owner");
38+
39+
await close();
40+
} finally {
41+
await fs.remove(tempDir);
42+
}
43+
});

0 commit comments

Comments
 (0)