@@ -57,7 +57,10 @@ def __init__(
5757 self .server = LanguageServer (server_name , version )
5858 self .context_class = context_class
5959 self .lsp_context : t .Optional [LSPContext ] = None
60- self .lint_cache : t .Dict [URI , t .List [AnnotatedRuleViolation ]] = {}
60+ # Cache stores tuples of (diagnostics, result_id)
61+ self .lint_cache : t .Dict [URI , t .Tuple [t .List [AnnotatedRuleViolation ], str ]] = {}
62+ self .client_supports_pull_diagnostics = False
63+ self ._diagnostic_version_counter = 0
6164
6265 # Register LSP features (e.g., formatting, hover, etc.)
6366 self ._register_features ()
@@ -69,6 +72,18 @@ def _register_features(self) -> None:
6972 def initialize (ls : LanguageServer , params : types .InitializeParams ) -> None :
7073 """Initialize the server when the client connects."""
7174 try :
75+ # Check if client supports pull diagnostics
76+ if params .capabilities and params .capabilities .text_document :
77+ diagnostics = getattr (params .capabilities .text_document , "diagnostic" , None )
78+ if diagnostics :
79+ self .client_supports_pull_diagnostics = True
80+ ls .log_trace ("Client supports pull diagnostics" )
81+ else :
82+ self .client_supports_pull_diagnostics = False
83+ ls .log_trace ("Client does not support pull diagnostics" )
84+ else :
85+ self .client_supports_pull_diagnostics = False
86+
7287 if params .workspace_folders :
7388 # Try to find a SQLMesh config file in any workspace folder (only at the root level)
7489 for folder in params .workspace_folders :
@@ -153,61 +168,74 @@ def api(ls: LanguageServer, request: ApiRequest) -> t.Dict[str, t.Any]:
153168 def did_open (ls : LanguageServer , params : types .DidOpenTextDocumentParams ) -> None :
154169 uri = URI (params .text_document .uri )
155170 context = self ._context_get_or_load (uri )
156- if self .lint_cache .get (uri ) is not None :
171+ models = context .map [uri .to_path ()]
172+ if models is None or not isinstance (models , ModelTarget ):
173+ return
174+
175+ if self .lint_cache .get (uri ) is None :
176+ diagnostics = context .context .lint_models (
177+ models .names ,
178+ raise_on_error = False ,
179+ )
180+ self ._diagnostic_version_counter += 1
181+ result_id = str (self ._diagnostic_version_counter )
182+ self .lint_cache [uri ] = (diagnostics , result_id )
183+
184+ # Only publish diagnostics if client doesn't support pull diagnostics
185+ if not self .client_supports_pull_diagnostics :
186+ diagnostics , _ = self .lint_cache [uri ]
157187 ls .publish_diagnostics (
158188 params .text_document .uri ,
159- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self . lint_cache [ uri ] ),
189+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
160190 )
161- return
162- models = context .map [uri .to_path ()]
163- if models is None :
164- return
165- if not isinstance (models , ModelTarget ):
166- return
167- self .lint_cache [uri ] = context .context .lint_models (
168- models .names ,
169- raise_on_error = False ,
170- )
171- ls .publish_diagnostics (
172- params .text_document .uri ,
173- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
174- )
175191
176192 @self .server .feature (types .TEXT_DOCUMENT_DID_CHANGE )
177193 def did_change (ls : LanguageServer , params : types .DidChangeTextDocumentParams ) -> None :
178194 uri = URI (params .text_document .uri )
179195 context = self ._context_get_or_load (uri )
180196 models = context .map [uri .to_path ()]
181- if models is None :
182- return
183- if not isinstance (models , ModelTarget ):
197+ if models is None or not isinstance (models , ModelTarget ):
184198 return
185- self .lint_cache [uri ] = context .context .lint_models (
199+
200+ # Always update the cache
201+ diagnostics = context .context .lint_models (
186202 models .names ,
187203 raise_on_error = False ,
188204 )
189- ls .publish_diagnostics (
190- params .text_document .uri ,
191- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
192- )
205+ self ._diagnostic_version_counter += 1
206+ result_id = str (self ._diagnostic_version_counter )
207+ self .lint_cache [uri ] = (diagnostics , result_id )
208+
209+ # Only publish diagnostics if client doesn't support pull diagnostics
210+ if not self .client_supports_pull_diagnostics :
211+ ls .publish_diagnostics (
212+ params .text_document .uri ,
213+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
214+ )
193215
194216 @self .server .feature (types .TEXT_DOCUMENT_DID_SAVE )
195217 def did_save (ls : LanguageServer , params : types .DidSaveTextDocumentParams ) -> None :
196218 uri = URI (params .text_document .uri )
197219 context = self ._context_get_or_load (uri )
198220 models = context .map [uri .to_path ()]
199- if models is None :
200- return
201- if not isinstance (models , ModelTarget ):
221+ if models is None or not isinstance (models , ModelTarget ):
202222 return
203- self .lint_cache [uri ] = context .context .lint_models (
223+
224+ # Always update the cache
225+ diagnostics = context .context .lint_models (
204226 models .names ,
205227 raise_on_error = False ,
206228 )
207- ls .publish_diagnostics (
208- params .text_document .uri ,
209- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
210- )
229+ self ._diagnostic_version_counter += 1
230+ result_id = str (self ._diagnostic_version_counter )
231+ self .lint_cache [uri ] = (diagnostics , result_id )
232+
233+ # Only publish diagnostics if client doesn't support pull diagnostics
234+ if not self .client_supports_pull_diagnostics :
235+ ls .publish_diagnostics (
236+ params .text_document .uri ,
237+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
238+ )
211239
212240 @self .server .feature (types .TEXT_DOCUMENT_FORMATTING )
213241 def formatting (
@@ -327,6 +355,124 @@ def goto_definition(
327355 ls .show_message (f"Error getting references: { e } " , types .MessageType .Error )
328356 return []
329357
358+ @self .server .feature (types .TEXT_DOCUMENT_DIAGNOSTIC )
359+ def diagnostic (
360+ ls : LanguageServer , params : types .DocumentDiagnosticParams
361+ ) -> types .DocumentDiagnosticReport :
362+ """Handle diagnostic pull requests from the client."""
363+ try :
364+ uri = URI (params .text_document .uri )
365+ diagnostics , result_id = self ._get_diagnostics_for_uri (uri )
366+
367+ # Check if client provided a previous result ID
368+ if hasattr (params , "previous_result_id" ) and params .previous_result_id == result_id :
369+ # Return unchanged report if diagnostics haven't changed
370+ return types .RelatedUnchangedDocumentDiagnosticReport (
371+ kind = types .DocumentDiagnosticReportKind .Unchanged ,
372+ result_id = result_id ,
373+ )
374+
375+ return types .RelatedFullDocumentDiagnosticReport (
376+ kind = types .DocumentDiagnosticReportKind .Full ,
377+ items = diagnostics ,
378+ result_id = result_id ,
379+ )
380+ except Exception as e :
381+ ls .show_message (f"Error getting diagnostics: { e } " , types .MessageType .Error )
382+ return types .RelatedFullDocumentDiagnosticReport (
383+ kind = types .DocumentDiagnosticReportKind .Full ,
384+ items = [],
385+ )
386+
387+ @self .server .feature (types .WORKSPACE_DIAGNOSTIC )
388+ def workspace_diagnostic (
389+ ls : LanguageServer , params : types .WorkspaceDiagnosticParams
390+ ) -> types .WorkspaceDiagnosticReport :
391+ """Handle workspace-wide diagnostic pull requests from the client."""
392+ try :
393+ if self .lsp_context is None :
394+ current_path = Path .cwd ()
395+ self ._ensure_context_in_folder (current_path )
396+
397+ if self .lsp_context is None :
398+ return types .WorkspaceDiagnosticReport (items = [])
399+
400+ items : t .List [
401+ t .Union [
402+ types .WorkspaceFullDocumentDiagnosticReport ,
403+ types .WorkspaceUnchangedDocumentDiagnosticReport ,
404+ ]
405+ ] = []
406+
407+ # Get all SQL and Python model files from the context
408+ for path , target in self .lsp_context .map .items ():
409+ if isinstance (target , ModelTarget ):
410+ uri = URI .from_path (path )
411+ diagnostics , result_id = self ._get_diagnostics_for_uri (uri )
412+
413+ # Check if we have a previous result ID for this file
414+ previous_result_id = None
415+ if hasattr (params , "previous_result_ids" ) and params .previous_result_ids :
416+ for prev in params .previous_result_ids :
417+ if prev .uri == uri .value :
418+ previous_result_id = prev .value
419+ break
420+
421+ if previous_result_id and previous_result_id == result_id :
422+ # File hasn't changed
423+ items .append (
424+ types .WorkspaceUnchangedDocumentDiagnosticReport (
425+ kind = types .DocumentDiagnosticReportKind .Unchanged ,
426+ result_id = result_id ,
427+ uri = uri .value ,
428+ )
429+ )
430+ else :
431+ # File has changed or is new
432+ items .append (
433+ types .WorkspaceFullDocumentDiagnosticReport (
434+ kind = types .DocumentDiagnosticReportKind .Full ,
435+ result_id = result_id ,
436+ uri = uri .value ,
437+ items = diagnostics ,
438+ )
439+ )
440+
441+ return types .WorkspaceDiagnosticReport (items = items )
442+
443+ except Exception as e :
444+ ls .show_message (
445+ f"Error getting workspace diagnostics: { e } " , types .MessageType .Error
446+ )
447+ return types .WorkspaceDiagnosticReport (items = [])
448+
449+ def _get_diagnostics_for_uri (self , uri : URI ) -> t .Tuple [t .List [types .Diagnostic ], str ]:
450+ """Get diagnostics for a specific URI, returning (diagnostics, result_id)."""
451+ # Check if we have cached diagnostics
452+ if uri in self .lint_cache :
453+ diagnostics , result_id = self .lint_cache [uri ]
454+ return SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ), result_id
455+
456+ # Try to get diagnostics by loading context and linting
457+ try :
458+ context = self ._context_get_or_load (uri )
459+ models = context .map [uri .to_path ()]
460+ if models is None or not isinstance (models , ModelTarget ):
461+ return [], "0"
462+
463+ # Lint the models and cache the results
464+ diagnostics = context .context .lint_models (
465+ models .names ,
466+ raise_on_error = False ,
467+ )
468+ self ._diagnostic_version_counter += 1
469+ result_id = str (self ._diagnostic_version_counter )
470+ self .lint_cache [uri ] = (diagnostics , result_id )
471+ return SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ), result_id
472+ except Exception :
473+ # If we can't get diagnostics, return empty list with no result ID
474+ return [], "0"
475+
330476 def _context_get_or_load (self , document_uri : URI ) -> LSPContext :
331477 if self .lsp_context is None :
332478 self ._ensure_context_for_document (document_uri )
0 commit comments