Skip to content

Commit 1433c27

Browse files
committed
chore(vscode): introduce pull first diagnostics
1 parent 2785fc9 commit 1433c27

2 files changed

Lines changed: 222 additions & 33 deletions

File tree

sqlmesh/lsp/main.py

Lines changed: 179 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ def __init__(
5757
self.server = LanguageServer(server_name, version)
5858
self.context_class = context_class
5959
self.lsp_context: t.Optional[LSPContext] = None
60-
self.lint_cache: t.Dict[URI, t.List[AnnotatedRuleViolation]] = {}
60+
# Cache stores tuples of (diagnostics, result_id)
61+
self.lint_cache: t.Dict[URI, t.Tuple[t.List[AnnotatedRuleViolation], str]] = {}
62+
self.client_supports_pull_diagnostics = False
63+
self._diagnostic_version_counter = 0
6164

6265
# Register LSP features (e.g., formatting, hover, etc.)
6366
self._register_features()
@@ -69,6 +72,18 @@ def _register_features(self) -> None:
6972
def initialize(ls: LanguageServer, params: types.InitializeParams) -> None:
7073
"""Initialize the server when the client connects."""
7174
try:
75+
# Check if client supports pull diagnostics
76+
if params.capabilities and params.capabilities.text_document:
77+
diagnostics = getattr(params.capabilities.text_document, "diagnostic", None)
78+
if diagnostics:
79+
self.client_supports_pull_diagnostics = True
80+
ls.log_trace("Client supports pull diagnostics")
81+
else:
82+
self.client_supports_pull_diagnostics = False
83+
ls.log_trace("Client does not support pull diagnostics")
84+
else:
85+
self.client_supports_pull_diagnostics = False
86+
7287
if params.workspace_folders:
7388
# Try to find a SQLMesh config file in any workspace folder (only at the root level)
7489
for folder in params.workspace_folders:
@@ -153,61 +168,74 @@ def api(ls: LanguageServer, request: ApiRequest) -> t.Dict[str, t.Any]:
153168
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
154169
uri = URI(params.text_document.uri)
155170
context = self._context_get_or_load(uri)
156-
if self.lint_cache.get(uri) is not None:
171+
models = context.map[uri.to_path()]
172+
if models is None or not isinstance(models, ModelTarget):
173+
return
174+
175+
if self.lint_cache.get(uri) is None:
176+
diagnostics = context.context.lint_models(
177+
models.names,
178+
raise_on_error=False,
179+
)
180+
self._diagnostic_version_counter += 1
181+
result_id = str(self._diagnostic_version_counter)
182+
self.lint_cache[uri] = (diagnostics, result_id)
183+
184+
# Only publish diagnostics if client doesn't support pull diagnostics
185+
if not self.client_supports_pull_diagnostics:
186+
diagnostics, _ = self.lint_cache[uri]
157187
ls.publish_diagnostics(
158188
params.text_document.uri,
159-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
189+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
160190
)
161-
return
162-
models = context.map[uri.to_path()]
163-
if models is None:
164-
return
165-
if not isinstance(models, ModelTarget):
166-
return
167-
self.lint_cache[uri] = context.context.lint_models(
168-
models.names,
169-
raise_on_error=False,
170-
)
171-
ls.publish_diagnostics(
172-
params.text_document.uri,
173-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
174-
)
175191

176192
@self.server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
177193
def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
178194
uri = URI(params.text_document.uri)
179195
context = self._context_get_or_load(uri)
180196
models = context.map[uri.to_path()]
181-
if models is None:
182-
return
183-
if not isinstance(models, ModelTarget):
197+
if models is None or not isinstance(models, ModelTarget):
184198
return
185-
self.lint_cache[uri] = context.context.lint_models(
199+
200+
# Always update the cache
201+
diagnostics = context.context.lint_models(
186202
models.names,
187203
raise_on_error=False,
188204
)
189-
ls.publish_diagnostics(
190-
params.text_document.uri,
191-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
192-
)
205+
self._diagnostic_version_counter += 1
206+
result_id = str(self._diagnostic_version_counter)
207+
self.lint_cache[uri] = (diagnostics, result_id)
208+
209+
# Only publish diagnostics if client doesn't support pull diagnostics
210+
if not self.client_supports_pull_diagnostics:
211+
ls.publish_diagnostics(
212+
params.text_document.uri,
213+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
214+
)
193215

194216
@self.server.feature(types.TEXT_DOCUMENT_DID_SAVE)
195217
def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None:
196218
uri = URI(params.text_document.uri)
197219
context = self._context_get_or_load(uri)
198220
models = context.map[uri.to_path()]
199-
if models is None:
200-
return
201-
if not isinstance(models, ModelTarget):
221+
if models is None or not isinstance(models, ModelTarget):
202222
return
203-
self.lint_cache[uri] = context.context.lint_models(
223+
224+
# Always update the cache
225+
diagnostics = context.context.lint_models(
204226
models.names,
205227
raise_on_error=False,
206228
)
207-
ls.publish_diagnostics(
208-
params.text_document.uri,
209-
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(self.lint_cache[uri]),
210-
)
229+
self._diagnostic_version_counter += 1
230+
result_id = str(self._diagnostic_version_counter)
231+
self.lint_cache[uri] = (diagnostics, result_id)
232+
233+
# Only publish diagnostics if client doesn't support pull diagnostics
234+
if not self.client_supports_pull_diagnostics:
235+
ls.publish_diagnostics(
236+
params.text_document.uri,
237+
SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics),
238+
)
211239

212240
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
213241
def formatting(
@@ -327,6 +355,124 @@ def goto_definition(
327355
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
328356
return []
329357

358+
@self.server.feature(types.TEXT_DOCUMENT_DIAGNOSTIC)
359+
def diagnostic(
360+
ls: LanguageServer, params: types.DocumentDiagnosticParams
361+
) -> types.DocumentDiagnosticReport:
362+
"""Handle diagnostic pull requests from the client."""
363+
try:
364+
uri = URI(params.text_document.uri)
365+
diagnostics, result_id = self._get_diagnostics_for_uri(uri)
366+
367+
# Check if client provided a previous result ID
368+
if hasattr(params, "previous_result_id") and params.previous_result_id == result_id:
369+
# Return unchanged report if diagnostics haven't changed
370+
return types.RelatedUnchangedDocumentDiagnosticReport(
371+
kind=types.DocumentDiagnosticReportKind.Unchanged,
372+
result_id=result_id,
373+
)
374+
375+
return types.RelatedFullDocumentDiagnosticReport(
376+
kind=types.DocumentDiagnosticReportKind.Full,
377+
items=diagnostics,
378+
result_id=result_id,
379+
)
380+
except Exception as e:
381+
ls.show_message(f"Error getting diagnostics: {e}", types.MessageType.Error)
382+
return types.RelatedFullDocumentDiagnosticReport(
383+
kind=types.DocumentDiagnosticReportKind.Full,
384+
items=[],
385+
)
386+
387+
@self.server.feature(types.WORKSPACE_DIAGNOSTIC)
388+
def workspace_diagnostic(
389+
ls: LanguageServer, params: types.WorkspaceDiagnosticParams
390+
) -> types.WorkspaceDiagnosticReport:
391+
"""Handle workspace-wide diagnostic pull requests from the client."""
392+
try:
393+
if self.lsp_context is None:
394+
current_path = Path.cwd()
395+
self._ensure_context_in_folder(current_path)
396+
397+
if self.lsp_context is None:
398+
return types.WorkspaceDiagnosticReport(items=[])
399+
400+
items: t.List[
401+
t.Union[
402+
types.WorkspaceFullDocumentDiagnosticReport,
403+
types.WorkspaceUnchangedDocumentDiagnosticReport,
404+
]
405+
] = []
406+
407+
# Get all SQL and Python model files from the context
408+
for path, target in self.lsp_context.map.items():
409+
if isinstance(target, ModelTarget):
410+
uri = URI.from_path(path)
411+
diagnostics, result_id = self._get_diagnostics_for_uri(uri)
412+
413+
# Check if we have a previous result ID for this file
414+
previous_result_id = None
415+
if hasattr(params, "previous_result_ids") and params.previous_result_ids:
416+
for prev in params.previous_result_ids:
417+
if prev.uri == uri.value:
418+
previous_result_id = prev.value
419+
break
420+
421+
if previous_result_id and previous_result_id == result_id:
422+
# File hasn't changed
423+
items.append(
424+
types.WorkspaceUnchangedDocumentDiagnosticReport(
425+
kind=types.DocumentDiagnosticReportKind.Unchanged,
426+
result_id=result_id,
427+
uri=uri.value,
428+
)
429+
)
430+
else:
431+
# File has changed or is new
432+
items.append(
433+
types.WorkspaceFullDocumentDiagnosticReport(
434+
kind=types.DocumentDiagnosticReportKind.Full,
435+
result_id=result_id,
436+
uri=uri.value,
437+
items=diagnostics,
438+
)
439+
)
440+
441+
return types.WorkspaceDiagnosticReport(items=items)
442+
443+
except Exception as e:
444+
ls.show_message(
445+
f"Error getting workspace diagnostics: {e}", types.MessageType.Error
446+
)
447+
return types.WorkspaceDiagnosticReport(items=[])
448+
449+
def _get_diagnostics_for_uri(self, uri: URI) -> t.Tuple[t.List[types.Diagnostic], str]:
450+
"""Get diagnostics for a specific URI, returning (diagnostics, result_id)."""
451+
# Check if we have cached diagnostics
452+
if uri in self.lint_cache:
453+
diagnostics, result_id = self.lint_cache[uri]
454+
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics), result_id
455+
456+
# Try to get diagnostics by loading context and linting
457+
try:
458+
context = self._context_get_or_load(uri)
459+
models = context.map[uri.to_path()]
460+
if models is None or not isinstance(models, ModelTarget):
461+
return [], "0"
462+
463+
# Lint the models and cache the results
464+
diagnostics = context.context.lint_models(
465+
models.names,
466+
raise_on_error=False,
467+
)
468+
self._diagnostic_version_counter += 1
469+
result_id = str(self._diagnostic_version_counter)
470+
self.lint_cache[uri] = (diagnostics, result_id)
471+
return SQLMeshLanguageServer._diagnostics_to_lsp_diagnostics(diagnostics), result_id
472+
except Exception:
473+
# If we can't get diagnostics, return empty list with no result ID
474+
return [], "0"
475+
330476
def _context_get_or_load(self, document_uri: URI) -> LSPContext:
331477
if self.lsp_context is None:
332478
self._ensure_context_for_document(document_uri)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import { test, expect } from '@playwright/test';
2+
import path from 'path';
3+
import fs from 'fs-extra';
4+
import os from 'os';
5+
import { startVSCode, SUSHI_SOURCE_PATH } from './utils';
6+
7+
test('Workspace diagnostics show up in the diagnostics panel', async () => {
8+
const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'vscode-test-sushi-'));
9+
await fs.copy(SUSHI_SOURCE_PATH, tempDir);
10+
11+
const configPath = path.join(tempDir, 'config.py');
12+
const configContent = await fs.readFile(configPath, 'utf8');
13+
const updatedContent = configContent.replace('enabled=False', 'enabled=True');
14+
await fs.writeFile(configPath, updatedContent);
15+
16+
try {
17+
const { window, close } = await startVSCode(tempDir);
18+
19+
// Wait for the models folder to be visible
20+
await window.waitForSelector('text=models');
21+
22+
// Click on the models folder, excluding external_models
23+
await window.getByRole('treeitem', { name: 'models', exact: true }).locator('a').click();
24+
25+
// Open the customer_revenue_lifetime model
26+
await window.getByRole('treeitem', { name: 'customers.sql', exact: true }).locator('a').click();
27+
28+
await
29+
30+
// Open problems panel
31+
await window.keyboard.press(process.platform === 'darwin' ? 'Meta+Shift+P' : 'Control+Shift+P');
32+
await window.keyboard.type('View: Focus Problems');
33+
await window.keyboard.press('Enter');
34+
35+
36+
await window.waitForSelector('text=problems');
37+
await window.waitForSelector("text=All models should have an owner");
38+
39+
await close();
40+
} finally {
41+
await fs.remove(tempDir);
42+
}
43+
});

0 commit comments

Comments
 (0)