@@ -57,7 +57,10 @@ def __init__(
5757 self .server = LanguageServer (server_name , version )
5858 self .context_class = context_class
5959 self .lsp_context : t .Optional [LSPContext ] = None
60- self .lint_cache : t .Dict [URI , t .List [AnnotatedRuleViolation ]] = {}
60+ # Cache stores tuples of (diagnostics, result_id)
61+ self .lint_cache : t .Dict [URI , t .Tuple [t .List [AnnotatedRuleViolation ], str ]] = {}
62+ self .client_supports_pull_diagnostics = False
63+ self ._diagnostic_version_counter = 0
6164
6265 # Register LSP features (e.g., formatting, hover, etc.)
6366 self ._register_features ()
@@ -69,6 +72,18 @@ def _register_features(self) -> None:
6972 def initialize (ls : LanguageServer , params : types .InitializeParams ) -> None :
7073 """Initialize the server when the client connects."""
7174 try :
75+ # Check if client supports pull diagnostics
76+ if params .capabilities and params .capabilities .text_document :
77+ diagnostics = getattr (params .capabilities .text_document , "diagnostic" , None )
78+ if diagnostics :
79+ self .client_supports_pull_diagnostics = True
80+ ls .log_trace ("Client supports pull diagnostics" )
81+ else :
82+ self .client_supports_pull_diagnostics = False
83+ ls .log_trace ("Client does not support pull diagnostics" )
84+ else :
85+ self .client_supports_pull_diagnostics = False
86+
7287 if params .workspace_folders :
7388 # Try to find a SQLMesh config file in any workspace folder (only at the root level)
7489 for folder in params .workspace_folders :
@@ -153,61 +168,76 @@ def api(ls: LanguageServer, request: ApiRequest) -> t.Dict[str, t.Any]:
153168 def did_open (ls : LanguageServer , params : types .DidOpenTextDocumentParams ) -> None :
154169 uri = URI (params .text_document .uri )
155170 context = self ._context_get_or_load (uri )
156- if self .lint_cache .get (uri ) is not None :
171+
172+ # Always update the cache
173+ models = context .map [uri .to_path ()]
174+ if models is None or not isinstance (models , ModelTarget ):
175+ return
176+
177+ if self .lint_cache .get (uri ) is None :
178+ diagnostics = context .context .lint_models (
179+ models .names ,
180+ raise_on_error = False ,
181+ )
182+ self ._diagnostic_version_counter += 1
183+ result_id = str (self ._diagnostic_version_counter )
184+ self .lint_cache [uri ] = (diagnostics , result_id )
185+
186+ # Only publish diagnostics if client doesn't support pull diagnostics
187+ if not self .client_supports_pull_diagnostics :
188+ diagnostics , _ = self .lint_cache [uri ]
157189 ls .publish_diagnostics (
158190 params .text_document .uri ,
159- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self . lint_cache [ uri ] ),
191+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
160192 )
161- return
162- models = context .map [uri .to_path ()]
163- if models is None :
164- return
165- if not isinstance (models , ModelTarget ):
166- return
167- self .lint_cache [uri ] = context .context .lint_models (
168- models .names ,
169- raise_on_error = False ,
170- )
171- ls .publish_diagnostics (
172- params .text_document .uri ,
173- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
174- )
175193
176194 @self .server .feature (types .TEXT_DOCUMENT_DID_CHANGE )
177195 def did_change (ls : LanguageServer , params : types .DidChangeTextDocumentParams ) -> None :
178196 uri = URI (params .text_document .uri )
179197 context = self ._context_get_or_load (uri )
180198 models = context .map [uri .to_path ()]
181- if models is None :
199+ if models is None or not isinstance ( models , ModelTarget ) :
182200 return
183- if not isinstance ( models , ModelTarget ):
184- return
185- self . lint_cache [ uri ] = context .context .lint_models (
201+
202+ # Always update the cache
203+ diagnostics = context .context .lint_models (
186204 models .names ,
187205 raise_on_error = False ,
188206 )
189- ls .publish_diagnostics (
190- params .text_document .uri ,
191- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
192- )
207+ self ._diagnostic_version_counter += 1
208+ result_id = str (self ._diagnostic_version_counter )
209+ self .lint_cache [uri ] = (diagnostics , result_id )
210+
211+ # Only publish diagnostics if client doesn't support pull diagnostics
212+ if not self .client_supports_pull_diagnostics :
213+ ls .publish_diagnostics (
214+ params .text_document .uri ,
215+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
216+ )
193217
194218 @self .server .feature (types .TEXT_DOCUMENT_DID_SAVE )
195219 def did_save (ls : LanguageServer , params : types .DidSaveTextDocumentParams ) -> None :
196220 uri = URI (params .text_document .uri )
197221 context = self ._context_get_or_load (uri )
198222 models = context .map [uri .to_path ()]
199- if models is None :
200- return
201- if not isinstance (models , ModelTarget ):
223+ if models is None or not isinstance (models , ModelTarget ):
202224 return
203- self .lint_cache [uri ] = context .context .lint_models (
225+
226+ # Always update the cache
227+ diagnostics = context .context .lint_models (
204228 models .names ,
205229 raise_on_error = False ,
206230 )
207- ls .publish_diagnostics (
208- params .text_document .uri ,
209- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
210- )
231+ self ._diagnostic_version_counter += 1
232+ result_id = str (self ._diagnostic_version_counter )
233+ self .lint_cache [uri ] = (diagnostics , result_id )
234+
235+ # Only publish diagnostics if client doesn't support pull diagnostics
236+ if not self .client_supports_pull_diagnostics :
237+ ls .publish_diagnostics (
238+ params .text_document .uri ,
239+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
240+ )
211241
212242 @self .server .feature (types .TEXT_DOCUMENT_FORMATTING )
213243 def formatting (
@@ -327,6 +357,124 @@ def goto_definition(
327357 ls .show_message (f"Error getting references: { e } " , types .MessageType .Error )
328358 return []
329359
360+ @self .server .feature (types .TEXT_DOCUMENT_DIAGNOSTIC )
361+ def diagnostic (
362+ ls : LanguageServer , params : types .DocumentDiagnosticParams
363+ ) -> types .DocumentDiagnosticReport :
364+ """Handle diagnostic pull requests from the client."""
365+ try :
366+ uri = URI (params .text_document .uri )
367+ diagnostics , result_id = self ._get_diagnostics_for_uri (uri )
368+
369+ # Check if client provided a previous result ID
370+ if hasattr (params , "previous_result_id" ) and params .previous_result_id == result_id :
371+ # Return unchanged report if diagnostics haven't changed
372+ return types .RelatedUnchangedDocumentDiagnosticReport (
373+ kind = types .DocumentDiagnosticReportKind .Unchanged ,
374+ result_id = result_id ,
375+ )
376+
377+ return types .RelatedFullDocumentDiagnosticReport (
378+ kind = types .DocumentDiagnosticReportKind .Full ,
379+ items = diagnostics ,
380+ result_id = result_id ,
381+ )
382+ except Exception as e :
383+ ls .show_message (f"Error getting diagnostics: { e } " , types .MessageType .Error )
384+ return types .RelatedFullDocumentDiagnosticReport (
385+ kind = types .DocumentDiagnosticReportKind .Full ,
386+ items = [],
387+ )
388+
389+ @self .server .feature (types .WORKSPACE_DIAGNOSTIC )
390+ def workspace_diagnostic (
391+ ls : LanguageServer , params : types .WorkspaceDiagnosticParams
392+ ) -> types .WorkspaceDiagnosticReport :
393+ """Handle workspace-wide diagnostic pull requests from the client."""
394+ try :
395+ if self .lsp_context is None :
396+ current_path = Path .cwd ()
397+ self ._ensure_context_in_folder (current_path )
398+
399+ if self .lsp_context is None :
400+ return types .WorkspaceDiagnosticReport (items = [])
401+
402+ items : t .List [
403+ t .Union [
404+ types .WorkspaceFullDocumentDiagnosticReport ,
405+ types .WorkspaceUnchangedDocumentDiagnosticReport ,
406+ ]
407+ ] = []
408+
409+ # Get all SQL and Python model files from the context
410+ for path , target in self .lsp_context .map .items ():
411+ if isinstance (target , ModelTarget ):
412+ uri = URI .from_path (path )
413+ diagnostics , result_id = self ._get_diagnostics_for_uri (uri )
414+
415+ # Check if we have a previous result ID for this file
416+ previous_result_id = None
417+ if hasattr (params , "previous_result_ids" ) and params .previous_result_ids :
418+ for prev in params .previous_result_ids :
419+ if prev .uri == uri .value :
420+ previous_result_id = prev .value
421+ break
422+
423+ if previous_result_id and previous_result_id == result_id :
424+ # File hasn't changed
425+ items .append (
426+ types .WorkspaceUnchangedDocumentDiagnosticReport (
427+ kind = types .DocumentDiagnosticReportKind .Unchanged ,
428+ result_id = result_id ,
429+ uri = uri .value ,
430+ )
431+ )
432+ else :
433+ # File has changed or is new
434+ items .append (
435+ types .WorkspaceFullDocumentDiagnosticReport (
436+ kind = types .DocumentDiagnosticReportKind .Full ,
437+ result_id = result_id ,
438+ uri = uri .value ,
439+ items = diagnostics ,
440+ )
441+ )
442+
443+ return types .WorkspaceDiagnosticReport (items = items )
444+
445+ except Exception as e :
446+ ls .show_message (
447+ f"Error getting workspace diagnostics: { e } " , types .MessageType .Error
448+ )
449+ return types .WorkspaceDiagnosticReport (items = [])
450+
451+ def _get_diagnostics_for_uri (self , uri : URI ) -> t .Tuple [t .List [types .Diagnostic ], str ]:
452+ """Get diagnostics for a specific URI, returning (diagnostics, result_id)."""
453+ # Check if we have cached diagnostics
454+ if uri in self .lint_cache :
455+ diagnostics , result_id = self .lint_cache [uri ]
456+ return SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ), result_id
457+
458+ # Try to get diagnostics by loading context and linting
459+ try :
460+ context = self ._context_get_or_load (uri )
461+ models = context .map [uri .to_path ()]
462+ if models is None or not isinstance (models , ModelTarget ):
463+ return [], "0"
464+
465+ # Lint the models and cache the results
466+ diagnostics = context .context .lint_models (
467+ models .names ,
468+ raise_on_error = False ,
469+ )
470+ self ._diagnostic_version_counter += 1
471+ result_id = str (self ._diagnostic_version_counter )
472+ self .lint_cache [uri ] = (diagnostics , result_id )
473+ return SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ), result_id
474+ except Exception :
475+ # If we can't get diagnostics, return empty list with no result ID
476+ return [], "0"
477+
330478 def _context_get_or_load (self , document_uri : URI ) -> LSPContext :
331479 if self .lsp_context is None :
332480 self ._ensure_context_for_document (document_uri )
0 commit comments