11#!/usr/bin/env python
22"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals."""
33
4+ from collections import defaultdict
45import logging
56import typing as t
67from pathlib import Path
78
89from lsprotocol import types
910from pygls .server import LanguageServer
10- from pygls .workspace import TextDocument
1111
1212from sqlmesh ._version import __version__
1313from sqlmesh .core .context import Context
14+ from sqlmesh .core .linter .definition import AnnotatedRuleViolation
15+
16+
17+ class LSPContext :
18+ """
19+ A context that is used for linting. It contains the context and a reverse map of file uri to model names .
20+ """
21+
22+ def __init__ (self , context : Context ) -> None :
23+ self .context = context
24+ map : t .Dict [str , t .List [str ]] = defaultdict (list [str ])
25+ for model in context .models .values ():
26+ if model ._path is None :
27+ path = Path (model ._path ).resolve ()
28+ map [f"file://{ path .as_posix ()} " ].append (model .name )
29+
30+ self .map = map
1431
1532
1633class SQLMeshLanguageServer :
@@ -27,29 +44,88 @@ def __init__(
2744 """
2845 self .server = LanguageServer (server_name , version )
2946 self .context_class = context_class
30- self .context : t .Optional [Context ] = None
47+ # A tuple of (context, reverse_map) where the reverse_map is uri to model name
48+ self .context : t .Optional [LSPContext ] = None
49+ self .lint_cache : t .Dict [str , t .List [AnnotatedRuleViolation ]] = {}
3150
3251 # Register LSP features (e.g., formatting, hover, etc.)
3352 self ._register_features ()
3453
3554 def _register_features (self ) -> None :
3655 """Register LSP features on the internal LanguageServer instance."""
3756
57+ @self .server .feature (types .TEXT_DOCUMENT_DID_OPEN )
58+ def did_open (ls : LanguageServer , params : types .DidOpenTextDocumentParams ) -> None :
59+ context = self ._context_get_or_load (params .text_document .uri )
60+ if self .lint_cache .get (params .text_document .uri ) is not None :
61+ ls .publish_diagnostics (
62+ params .text_document .uri ,
63+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (
64+ self .lint_cache [params .text_document .uri ]
65+ ),
66+ )
67+ return
68+ models = context .map [params .text_document .uri ]
69+ if models is None :
70+ return
71+ self .lint_cache [params .text_document .uri ] = context .context .lint_models (
72+ models ,
73+ raise_on_error = False ,
74+ )
75+ ls .publish_diagnostics (
76+ params .text_document .uri ,
77+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (
78+ self .lint_cache [params .text_document .uri ]
79+ ),
80+ )
81+
82+ @self .server .feature (types .TEXT_DOCUMENT_DID_CHANGE )
83+ def did_change (ls : LanguageServer , params : types .DidChangeTextDocumentParams ) -> None :
84+ context = self ._context_get_or_load (params .text_document .uri )
85+ models = context .map [params .text_document .uri ]
86+ if models is None :
87+ return
88+ self .lint_cache [params .text_document .uri ] = context .context .lint_models (
89+ models ,
90+ raise_on_error = False ,
91+ )
92+ ls .publish_diagnostics (
93+ params .text_document .uri ,
94+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (
95+ self .lint_cache [params .text_document .uri ]
96+ ),
97+ )
98+
99+ @self .server .feature (types .TEXT_DOCUMENT_DID_SAVE )
100+ def did_save (ls : LanguageServer , params : types .DidSaveTextDocumentParams ) -> None :
101+ context = self ._context_get_or_load (params .text_document .uri )
102+ models = context .map [params .text_document .uri ]
103+ if models is None :
104+ return
105+ self .lint_cache [params .text_document .uri ] = context .context .lint_models (
106+ models ,
107+ raise_on_error = False ,
108+ )
109+ ls .publish_diagnostics (
110+ params .text_document .uri ,
111+ SQLMeshLanguageServer ._diagnostics_to_lsp_diagnostics (
112+ self .lint_cache [params .text_document .uri ]
113+ ),
114+ )
115+
38116 @self .server .feature (types .TEXT_DOCUMENT_FORMATTING )
39117 def formatting (
40118 ls : LanguageServer , params : types .DocumentFormattingParams
41119 ) -> t .List [types .TextEdit ]:
42120 """Format the document using SQLMesh `format_model_expressions`."""
43121 try :
44- document = self .ensure_context_for_document (
45- ls .workspace .get_document (params .text_document .uri )
46- )
47-
122+ self ._ensure_context_for_document (params .text_document .uri )
123+ document = ls .workspace .get_document (params .text_document .uri )
48124 if self .context is None :
49125 raise RuntimeError (f"No context found for document: { document .path } " )
50126
51127 # Perform formatting using the loaded context
52- self .context .format (paths = (Path (document .path ),))
128+ self .context .context . format (paths = (Path (document .path ),))
53129 with open (document .path , "r+" , encoding = "utf-8" ) as file :
54130 new_text = file .read ()
55131
@@ -70,20 +146,32 @@ def formatting(
70146 ls .show_message (f"Error formatting SQL: { e } " , types .MessageType .Error )
71147 return []
72148
73- def ensure_context_for_document (self , document : TextDocument ) -> TextDocument :
149+ def _context_get_or_load (self , document_uri : str ) -> LSPContext :
150+ if self .context is None :
151+ self ._ensure_context_for_document (document_uri )
152+ if self .context is None :
153+ raise RuntimeError ("No context found" )
154+ return self .context
155+
156+ def _ensure_context_for_document (
157+ self ,
158+ document_uri : str ,
159+ ) -> None :
74160 """
75161 Ensure that a context exists for the given document if applicable by searching
76162 for a config.py or config.yml file in the parent directories.
77163 """
78164 # If the context is already loaded, check if this document belongs to it.
79165 if self .context is not None :
80- self .context .load () # Reload or refresh context
81- return document
166+ context = self .context
167+ context .context .load () # Reload or refresh context
168+ self .context = LSPContext (context .context )
169+ return
82170
83171 # No context yet: try to find config and load it
84- path = Path (document . path ).resolve ()
172+ path = Path (self . _uri_to_path ( document_uri ) ).resolve ()
85173 if path .suffix not in (".sql" , ".py" ):
86- return document
174+ return
87175
88176 loaded = False
89177 # Ascend directories to look for config
@@ -93,18 +181,57 @@ def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
93181 if config_path .exists ():
94182 try :
95183 # Use user-provided instantiator to build the context
96- self . context = self .context_class (paths = [path ])
97- self .server . show_message ( f"Context loaded for: { path } " )
184+ created_context = self .context_class (paths = [path ])
185+ self .context = LSPContext ( created_context )
98186 loaded = True
99187 # Re-check context for document now that it's loaded
100- return self .ensure_context_for_document ( document )
188+ return self ._ensure_context_for_document ( document_uri )
101189 except Exception as e :
102190 self .server .show_message (
103191 f"Error loading context: { e } " , types .MessageType .Error
104192 )
105193 path = path .parent
106194
107- return document
195+ return
196+
197+ @staticmethod
198+ def _diagnostic_to_lsp_diagnostic (
199+ diagnostic : AnnotatedRuleViolation ,
200+ ) -> t .Optional [types .Diagnostic ]:
201+ if diagnostic .model ._path is None :
202+ return None
203+ with open (diagnostic .model ._path , "r" , encoding = "utf-8" ) as file :
204+ lines = file .readlines ()
205+ return types .Diagnostic (
206+ range = types .Range (
207+ start = types .Position (line = 0 , character = 0 ),
208+ end = types .Position (line = len (lines ), character = len (lines [- 1 ])),
209+ ),
210+ message = diagnostic .violation_msg ,
211+ severity = types .DiagnosticSeverity .Error
212+ if diagnostic .violation_type == "error"
213+ else types .DiagnosticSeverity .Warning ,
214+ )
215+
216+ @staticmethod
217+ def _diagnostics_to_lsp_diagnostics (
218+ diagnostics : t .List [AnnotatedRuleViolation ],
219+ ) -> t .List [types .Diagnostic ]:
220+ lsp_diagnostics : t .List [types .Diagnostic ] = []
221+ for diagnostic in diagnostics :
222+ if diagnostic is None :
223+ continue
224+ lsp_diagnostic = SQLMeshLanguageServer ._diagnostic_to_lsp_diagnostic (diagnostic )
225+ if lsp_diagnostic is not None :
226+ lsp_diagnostics .append (lsp_diagnostic )
227+ return lsp_diagnostics
228+
229+ @staticmethod
230+ def _uri_to_path (uri : str ) -> str :
231+ """Convert a URI to a path."""
232+ if uri .startswith ("file://" ):
233+ return Path (uri [7 :]).resolve ().as_posix ()
234+ return Path (uri ).resolve ().as_posix ()
108235
109236 def start (self ) -> None :
110237 """Start the server with I/O transport."""
0 commit comments