@@ -57,7 +57,10 @@ def __init__(
5757 self .server = LanguageServer (server_name , version )
5858 self .context_class = context_class
5959 self .lsp_context : t .Optional [LSPContext ] = None
60- self .lint_cache : t .Dict [URI , t .List [AnnotatedRuleViolation ]] = {}
60+ # Cache stores tuples of (diagnostics, result_id)
61+ self .lint_cache : t .Dict [URI , t .Tuple [t .List [AnnotatedRuleViolation ], str ]] = {}
62+ self .client_supports_pull_diagnostics = False
63+ self ._diagnostic_version_counter = 0
6164
6265 # Register LSP features (e.g., formatting, hover, etc.)
6366 self ._register_features ()
@@ -69,6 +72,18 @@ def _register_features(self) -> None:
6972 def initialize (ls : LanguageServer , params : types .InitializeParams ) -> None :
7073 """Initialize the server when the client connects."""
7174 try :
75+ # Check if client supports pull diagnostics
76+ if params .capabilities and params .capabilities .text_document :
77+ diagnostics = getattr (params .capabilities .text_document , 'diagnostic' , None )
78+ if diagnostics :
79+ self .client_supports_pull_diagnostics = True
80+ ls .log_trace ("Client supports pull diagnostics" )
81+ else :
82+ self .client_supports_pull_diagnostics = False
83+ ls .log_trace ("Client does not support pull diagnostics" )
84+ else :
85+ self .client_supports_pull_diagnostics = False
86+
7287 if params .workspace_folders :
7388 # Try to find a SQLMesh config file in any workspace folder (only at the root level)
7489 for folder in params .workspace_folders :
@@ -153,61 +168,76 @@ def api(ls: LanguageServer, request: ApiRequest) -> t.Dict[str, t.Any]:
153168 def did_open (ls : LanguageServer , params : types .DidOpenTextDocumentParams ) -> None :
154169 uri = URI (params .text_document .uri )
155170 context = self ._context_get_or_load (uri )
156- if self .lint_cache .get (uri ) is not None :
171+
172+ # Always update the cache
173+ models = context .map [uri .to_path ()]
174+ if models is None or not isinstance (models , ModelTarget ):
175+ return
176+
177+ if self .lint_cache .get (uri ) is None :
178+ diagnostics = context .context .lint_models (
179+ models .names ,
180+ raise_on_error = False ,
181+ )
182+ self ._diagnostic_version_counter += 1
183+ result_id = str (self ._diagnostic_version_counter )
184+ self .lint_cache [uri ] = (diagnostics , result_id )
185+
186+ # Only publish diagnostics if client doesn't support pull diagnostics
187+ if not self .client_supports_pull_diagnostics :
188+ diagnostics , _ = self .lint_cache [uri ]
157189 ls .publish_diagnostics (
158190 params .text_document .uri ,
159- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self . lint_cache [ uri ] ),
191+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
160192 )
161- return
162- models = context .map [uri .to_path ()]
163- if models is None :
164- return
165- if not isinstance (models , ModelTarget ):
166- return
167- self .lint_cache [uri ] = context .context .lint_models (
168- models .names ,
169- raise_on_error = False ,
170- )
171- ls .publish_diagnostics (
172- params .text_document .uri ,
173- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
174- )
175193
176194 @self .server .feature (types .TEXT_DOCUMENT_DID_CHANGE )
177195 def did_change (ls : LanguageServer , params : types .DidChangeTextDocumentParams ) -> None :
178196 uri = URI (params .text_document .uri )
179197 context = self ._context_get_or_load (uri )
180198 models = context .map [uri .to_path ()]
181- if models is None :
182- return
183- if not isinstance (models , ModelTarget ):
199+ if models is None or not isinstance (models , ModelTarget ):
184200 return
185- self .lint_cache [uri ] = context .context .lint_models (
201+
202+ # Always update the cache
203+ diagnostics = context .context .lint_models (
186204 models .names ,
187205 raise_on_error = False ,
188206 )
189- ls .publish_diagnostics (
190- params .text_document .uri ,
191- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
192- )
207+ self ._diagnostic_version_counter += 1
208+ result_id = str (self ._diagnostic_version_counter )
209+ self .lint_cache [uri ] = (diagnostics , result_id )
210+
211+ # Only publish diagnostics if client doesn't support pull diagnostics
212+ if not self .client_supports_pull_diagnostics :
213+ ls .publish_diagnostics (
214+ params .text_document .uri ,
215+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
216+ )
193217
194218 @self .server .feature (types .TEXT_DOCUMENT_DID_SAVE )
195219 def did_save (ls : LanguageServer , params : types .DidSaveTextDocumentParams ) -> None :
196220 uri = URI (params .text_document .uri )
197221 context = self ._context_get_or_load (uri )
198222 models = context .map [uri .to_path ()]
199- if models is None :
223+ if models is None or not isinstance ( models , ModelTarget ) :
200224 return
201- if not isinstance ( models , ModelTarget ):
202- return
203- self . lint_cache [ uri ] = context .context .lint_models (
225+
226+ # Always update the cache
227+ diagnostics = context .context .lint_models (
204228 models .names ,
205229 raise_on_error = False ,
206230 )
207- ls .publish_diagnostics (
208- params .text_document .uri ,
209- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
210- )
231+ self ._diagnostic_version_counter += 1
232+ result_id = str (self ._diagnostic_version_counter )
233+ self .lint_cache [uri ] = (diagnostics , result_id )
234+
235+ # Only publish diagnostics if client doesn't support pull diagnostics
236+ if not self .client_supports_pull_diagnostics :
237+ ls .publish_diagnostics (
238+ params .text_document .uri ,
239+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
240+ )
211241
212242 @self .server .feature (types .TEXT_DOCUMENT_FORMATTING )
213243 def formatting (
@@ -326,6 +356,117 @@ def goto_definition(
326356 except Exception as e :
327357 ls .show_message (f"Error getting references: { e } " , types .MessageType .Error )
328358 return []
359+
360+ @self .server .feature (types .TEXT_DOCUMENT_DIAGNOSTIC )
361+ def diagnostic (
362+ ls : LanguageServer , params : types .DocumentDiagnosticParams
363+ ) -> types .DocumentDiagnosticReport :
364+ """Handle diagnostic pull requests from the client."""
365+ try :
366+ uri = URI (params .text_document .uri )
367+ diagnostics , result_id = self ._get_diagnostics_for_uri (uri )
368+
369+ # Check if client provided a previous result ID
370+ if hasattr (params , 'previous_result_id' ) and params .previous_result_id == result_id :
371+ # Return unchanged report if diagnostics haven't changed
372+ return types .RelatedUnchangedDocumentDiagnosticReport (
373+ kind = types .DocumentDiagnosticReportKind .Unchanged ,
374+ result_id = result_id ,
375+ )
376+
377+ return types .RelatedFullDocumentDiagnosticReport (
378+ kind = types .DocumentDiagnosticReportKind .Full ,
379+ items = diagnostics ,
380+ result_id = result_id ,
381+ )
382+ except Exception as e :
383+ ls .show_message (f"Error getting diagnostics: { e } " , types .MessageType .Error )
384+ return types .RelatedFullDocumentDiagnosticReport (
385+ kind = types .DocumentDiagnosticReportKind .Full ,
386+ items = [],
387+ )
388+
389+ @self .server .feature (types .WORKSPACE_DIAGNOSTIC )
390+ def workspace_diagnostic (
391+ ls : LanguageServer , params : types .WorkspaceDiagnosticParams
392+ ) -> types .WorkspaceDiagnosticReport :
393+ """Handle workspace-wide diagnostic pull requests from the client."""
394+ try :
395+ if self .lsp_context is None :
396+ current_path = Path .cwd ()
397+ self ._ensure_context_in_folder (current_path )
398+
399+ if self .lsp_context is None :
400+ return types .WorkspaceDiagnosticReport (items = [])
401+
402+ items = []
403+
404+ # Get all SQL and Python model files from the context
405+ for path , target in self .lsp_context .map .items ():
406+ if isinstance (target , ModelTarget ):
407+ uri = URI .from_path (path )
408+ diagnostics , result_id = self ._get_diagnostics_for_uri (uri )
409+
410+ # Check if we have a previous result ID for this file
411+ previous_result_id = None
412+ if hasattr (params , 'previous_result_ids' ) and params .previous_result_ids :
413+ for prev in params .previous_result_ids :
414+ if prev .uri == uri .value :
415+ previous_result_id = prev .value
416+ break
417+
418+ if previous_result_id and previous_result_id == result_id :
419+ # File hasn't changed
420+ items .append (
421+ types .WorkspaceUnchangedDocumentDiagnosticReport (
422+ kind = types .DocumentDiagnosticReportKind .Unchanged ,
423+ result_id = result_id ,
424+ uri = uri .value ,
425+ )
426+ )
427+ else :
428+ # File has changed or is new
429+ items .append (
430+ types .WorkspaceFullDocumentDiagnosticReport (
431+ kind = types .DocumentDiagnosticReportKind .Full ,
432+ result_id = result_id ,
433+ uri = uri .value ,
434+ items = diagnostics ,
435+ )
436+ )
437+
438+ return types .WorkspaceDiagnosticReport (items = items )
439+
440+ except Exception as e :
441+ ls .show_message (f"Error getting workspace diagnostics: { e } " , types .MessageType .Error )
442+ return types .WorkspaceDiagnosticReport (items = [])
443+
444+ def _get_diagnostics_for_uri (self , uri : URI ) -> t .Tuple [t .List [types .Diagnostic ], str ]:
445+ """Get diagnostics for a specific URI, returning (diagnostics, result_id)."""
446+ # Check if we have cached diagnostics
447+ if uri in self .lint_cache :
448+ diagnostics , result_id = self .lint_cache [uri ]
449+ return SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ), result_id
450+
451+ # Try to get diagnostics by loading context and linting
452+ try :
453+ context = self ._context_get_or_load (uri )
454+ models = context .map [uri .to_path ()]
455+ if models is None or not isinstance (models , ModelTarget ):
456+ return [], "0"
457+
458+ # Lint the models and cache the results
459+ diagnostics = context .context .lint_models (
460+ models .names ,
461+ raise_on_error = False ,
462+ )
463+ self ._diagnostic_version_counter += 1
464+ result_id = str (self ._diagnostic_version_counter )
465+ self .lint_cache [uri ] = (diagnostics , result_id )
466+ return SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ), result_id
467+ except Exception :
468+ # If we can't get diagnostics, return empty list with no result ID
469+ return [], "0"
329470
330471 def _context_get_or_load (self , document_uri : URI ) -> LSPContext :
331472 if self .lsp_context is None :
0 commit comments