@@ -57,8 +57,12 @@ def __init__(
5757 self .server = LanguageServer (server_name , version )
5858 self .context_class = context_class
5959 self .lsp_context : t .Optional [LSPContext ] = None
60- self .lint_cache : t .Dict [URI , t .List [AnnotatedRuleViolation ]] = {}
6160
61+ # Cache stores tuples of (diagnostics, diagnostic_version)
62+ self .lint_cache : t .Dict [URI , t .Tuple [t .List [AnnotatedRuleViolation ], int ]] = {}
63+ self ._diagnostic_version_counter : int = 0
64+
65+ self .client_supports_pull_diagnostics = False
6266 # Register LSP features (e.g., formatting, hover, etc.)
6367 self ._register_features ()
6468
@@ -69,6 +73,18 @@ def _register_features(self) -> None:
6973 def initialize (ls : LanguageServer , params : types .InitializeParams ) -> None :
7074 """Initialize the server when the client connects."""
7175 try :
76+ # Check if client supports pull diagnostics
77+ if params .capabilities and params .capabilities .text_document :
78+ diagnostics = getattr (params .capabilities .text_document , "diagnostic" , None )
79+ if diagnostics :
80+ self .client_supports_pull_diagnostics = True
81+ ls .log_trace ("Client supports pull diagnostics" )
82+ else :
83+ self .client_supports_pull_diagnostics = False
84+ ls .log_trace ("Client does not support pull diagnostics" )
85+ else :
86+ self .client_supports_pull_diagnostics = False
87+
7288 if params .workspace_folders :
7389 # Try to find a SQLMesh config file in any workspace folder (only at the root level)
7490 for folder in params .workspace_folders :
@@ -153,61 +169,71 @@ def api(ls: LanguageServer, request: ApiRequest) -> t.Dict[str, t.Any]:
153169 def did_open (ls : LanguageServer , params : types .DidOpenTextDocumentParams ) -> None :
154170 uri = URI (params .text_document .uri )
155171 context = self ._context_get_or_load (uri )
156- if self .lint_cache .get (uri ) is not None :
172+ models = context .map [uri .to_path ()]
173+ if models is None or not isinstance (models , ModelTarget ):
174+ return
175+
176+ if self .lint_cache .get (uri ) is None :
177+ diagnostics = context .context .lint_models (
178+ models .names ,
179+ raise_on_error = False ,
180+ )
181+ self ._diagnostic_version_counter += 1
182+ self .lint_cache [uri ] = (diagnostics , self ._diagnostic_version_counter )
183+
184+ # Only publish diagnostics if client doesn't support pull diagnostics
185+ if not self .client_supports_pull_diagnostics :
186+ diagnostics , _ = self .lint_cache [uri ]
157187 ls .publish_diagnostics (
158188 params .text_document .uri ,
159- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self . lint_cache [ uri ] ),
189+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
160190 )
161- return
162- models = context .map [uri .to_path ()]
163- if models is None :
164- return
165- if not isinstance (models , ModelTarget ):
166- return
167- self .lint_cache [uri ] = context .context .lint_models (
168- models .names ,
169- raise_on_error = False ,
170- )
171- ls .publish_diagnostics (
172- params .text_document .uri ,
173- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
174- )
175191
176192 @self .server .feature (types .TEXT_DOCUMENT_DID_CHANGE )
177193 def did_change (ls : LanguageServer , params : types .DidChangeTextDocumentParams ) -> None :
178194 uri = URI (params .text_document .uri )
179195 context = self ._context_get_or_load (uri )
180196 models = context .map [uri .to_path ()]
181- if models is None :
197+ if models is None or not isinstance ( models , ModelTarget ) :
182198 return
183- if not isinstance ( models , ModelTarget ):
184- return
185- self . lint_cache [ uri ] = context .context .lint_models (
199+
200+ # Always update the cache
201+ diagnostics = context .context .lint_models (
186202 models .names ,
187203 raise_on_error = False ,
188204 )
189- ls .publish_diagnostics (
190- params .text_document .uri ,
191- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
192- )
205+ self ._diagnostic_version_counter += 1
206+ self .lint_cache [uri ] = (diagnostics , self ._diagnostic_version_counter )
207+
208+ # Only publish diagnostics if client doesn't support pull diagnostics
209+ if not self .client_supports_pull_diagnostics :
210+ ls .publish_diagnostics (
211+ params .text_document .uri ,
212+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
213+ )
193214
194215 @self .server .feature (types .TEXT_DOCUMENT_DID_SAVE )
195216 def did_save (ls : LanguageServer , params : types .DidSaveTextDocumentParams ) -> None :
196217 uri = URI (params .text_document .uri )
197218 context = self ._context_get_or_load (uri )
198219 models = context .map [uri .to_path ()]
199- if models is None :
200- return
201- if not isinstance (models , ModelTarget ):
220+ if models is None or not isinstance (models , ModelTarget ):
202221 return
203- self .lint_cache [uri ] = context .context .lint_models (
222+
223+ # Always update the cache
224+ diagnostics = context .context .lint_models (
204225 models .names ,
205226 raise_on_error = False ,
206227 )
207- ls .publish_diagnostics (
208- params .text_document .uri ,
209- SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (self .lint_cache [uri ]),
210- )
228+ self ._diagnostic_version_counter += 1
229+ self .lint_cache [uri ] = (diagnostics , self ._diagnostic_version_counter )
230+
231+ # Only publish diagnostics if client doesn't support pull diagnostics
232+ if not self .client_supports_pull_diagnostics :
233+ ls .publish_diagnostics (
234+ params .text_document .uri ,
235+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ),
236+ )
211237
212238 @self .server .feature (types .TEXT_DOCUMENT_FORMATTING )
213239 def formatting (
@@ -327,6 +353,125 @@ def goto_definition(
327353 ls .show_message (f"Error getting references: { e } " , types .MessageType .Error )
328354 return []
329355
356+ @self .server .feature (types .TEXT_DOCUMENT_DIAGNOSTIC )
357+ def diagnostic (
358+ ls : LanguageServer , params : types .DocumentDiagnosticParams
359+ ) -> types .DocumentDiagnosticReport :
360+ """Handle diagnostic pull requests from the client."""
361+ try :
362+ uri = URI (params .text_document .uri )
363+ diagnostics , result_id = self ._get_diagnostics_for_uri (uri )
364+
365+ # Check if client provided a previous result ID
366+ if hasattr (params , "previous_result_id" ) and params .previous_result_id == result_id :
367+ # Return unchanged report if diagnostics haven't changed
368+ return types .RelatedUnchangedDocumentDiagnosticReport (
369+ kind = types .DocumentDiagnosticReportKind .Unchanged ,
370+ result_id = str (result_id ),
371+ )
372+
373+ return types .RelatedFullDocumentDiagnosticReport (
374+ kind = types .DocumentDiagnosticReportKind .Full ,
375+ items = diagnostics ,
376+ result_id = str (result_id ),
377+ )
378+ except Exception as e :
379+ ls .show_message (f"Error getting diagnostics: { e } " , types .MessageType .Error )
380+ return types .RelatedFullDocumentDiagnosticReport (
381+ kind = types .DocumentDiagnosticReportKind .Full ,
382+ items = [],
383+ )
384+
385+ @self .server .feature (types .WORKSPACE_DIAGNOSTIC )
386+ def workspace_diagnostic (
387+ ls : LanguageServer , params : types .WorkspaceDiagnosticParams
388+ ) -> types .WorkspaceDiagnosticReport :
389+ """Handle workspace-wide diagnostic pull requests from the client."""
390+ try :
391+ if self .lsp_context is None :
392+ current_path = Path .cwd ()
393+ self ._ensure_context_in_folder (current_path )
394+
395+ if self .lsp_context is None :
396+ return types .WorkspaceDiagnosticReport (items = [])
397+
398+ items : t .List [
399+ t .Union [
400+ types .WorkspaceFullDocumentDiagnosticReport ,
401+ types .WorkspaceUnchangedDocumentDiagnosticReport ,
402+ ]
403+ ] = []
404+
405+ # Get all SQL and Python model files from the context
406+ for path , target in self .lsp_context .map .items ():
407+ if isinstance (target , ModelTarget ):
408+ uri = URI .from_path (path )
409+ diagnostics , result_id = self ._get_diagnostics_for_uri (uri )
410+
411+ # Check if we have a previous result ID for this file
412+ previous_result_id = None
413+ if hasattr (params , "previous_result_ids" ) and params .previous_result_ids :
414+ for prev in params .previous_result_ids :
415+ if prev .uri == uri .value :
416+ previous_result_id = prev .value
417+ break
418+
419+ if previous_result_id and previous_result_id == result_id :
420+ # File hasn't changed
421+ items .append (
422+ types .WorkspaceUnchangedDocumentDiagnosticReport (
423+ kind = types .DocumentDiagnosticReportKind .Unchanged ,
424+ result_id = str (result_id ),
425+ uri = uri .value ,
426+ )
427+ )
428+ else :
429+ # File has changed or is new
430+ items .append (
431+ types .WorkspaceFullDocumentDiagnosticReport (
432+ kind = types .DocumentDiagnosticReportKind .Full ,
433+ result_id = str (result_id ),
434+ uri = uri .value ,
435+ items = diagnostics ,
436+ )
437+ )
438+
439+ return types .WorkspaceDiagnosticReport (items = items )
440+
441+ except Exception as e :
442+ ls .show_message (
443+ f"Error getting workspace diagnostics: { e } " , types .MessageType .Error
444+ )
445+ return types .WorkspaceDiagnosticReport (items = [])
446+
447+ def _get_diagnostics_for_uri (self , uri : URI ) -> t .Tuple [t .List [types .Diagnostic ], int ]:
448+ """Get diagnostics for a specific URI, returning (diagnostics, result_id)."""
449+ # Check if we have cached diagnostics
450+ if uri in self .lint_cache :
451+ diagnostics , result_id = self .lint_cache [uri ]
452+ return SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (diagnostics ), result_id
453+
454+ # Try to get diagnostics by loading context and linting
455+ try :
456+ context = self ._context_get_or_load (uri )
457+ models = context .map [uri .to_path ()]
458+ if models is None or not isinstance (models , ModelTarget ):
459+ return [], 0
460+
461+ # Lint the models and cache the results
462+ diagnostics = context .context .lint_models (
463+ models .names ,
464+ raise_on_error = False ,
465+ )
466+ self ._diagnostic_version_counter += 1
467+ self .lint_cache [uri ] = (diagnostics , self ._diagnostic_version_counter )
468+ return SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (
469+ diagnostics
470+ ), self ._diagnostic_version_counter
471+ except Exception :
472+ # If we can't get diagnostics, return empty list with no result ID
473+ return [], 0
474+
330475 def _context_get_or_load (self , document_uri : URI ) -> LSPContext :
331476 if self .lsp_context is None :
332477 self ._ensure_context_for_document (document_uri )
0 commit comments