11#!/usr/bin/env python
22"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration."""
3+ #!/usr/bin/env python
4+ """A Language Server Protocol (LSP) server for SQL with SQLMesh integration."""
35
46import asyncio
57import gc
8+ import io
69import logging
710import re
811import typing as t
1720from lsprotocol import types
1821from pygls .server import LanguageServer
1922from pygls .workspace import TextDocument
23+ from sqlglot .errors import ParseError
2024from sqlmesh .core .dialect import format_model_expressions , parse
25+ from sqlmesh .core .lineage import column_description
26+ from sqlmesh .core .model import SqlModel
27+ from sqlmesh .utils import type_is_known
2128
2229logger = logging .getLogger (__name__ )
2330
3441"""A locking mechanism for ensuring that context mutation is thread-safe."""
3542
3643loop = asyncio .get_event_loop ()
37-
3844server = LanguageServer ("sqlmesh-lsp" , "v0.1.0" , loop = loop )
3945
4046
@@ -93,9 +99,34 @@ async def ensure_context_for_document(document: TextDocument) -> TextDocument:
9399 return document
94100
95101
@server.feature(types.TEXT_DOCUMENT_COMPLETION)
async def completions(ls: LanguageServer, params: types.CompletionParams):
    """Provide completions based on upstream model column information.

    For every dependency listed in the current model's ``depends_on``, one
    ``Field`` completion item is produced per entry of that dependency's
    ``columns_to_types``. The column's SQL type is shown as the label detail
    and the documentation carries the source dependency plus the column
    description (or a placeholder when none is found).

    Returns:
        A complete (``is_incomplete=False``) ``CompletionList``. It is empty
        when no (context, model) pair is registered for the document path.
    """
    document = await ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
    context, model = PATHS_TO_MODELS.get(document.path, (None, None))
    items = []
    if context is None or model is None:
        return types.CompletionList(is_incomplete=False, items=items)

    for dep in model.depends_on:
        # A name in ``depends_on`` is not guaranteed to be a key of
        # ``context.models``; skip it instead of failing the whole request
        # with a KeyError.
        try:
            model_dep = context.models[dep]
        except KeyError:
            continue
        if not model_dep.columns_to_types:
            continue
        for column, type_ in model_dep.columns_to_types.items():
            items.append(
                types.CompletionItem(
                    label=column,
                    label_details=types.CompletionItemLabelDetails(detail=type_.sql()),
                    documentation=f"Source: {dep}\n\n"
                    + (column_description(context, dep, column) or "No description available"),
                    kind=types.CompletionItemKind.Field,
                )
            )
    return types.CompletionList(is_incomplete=False, items=items)
124+
125+
96126@server .feature (types .TEXT_DOCUMENT_FORMATTING )
97- async def formatting (ls : LanguageServer , params : types .DocumentFormattingParams ):
127+ async def formatting (ls : LanguageServer , params : types .DocumentFormattingParams ) -> t . List [ types . TextEdit ] :
98128 """Format the document based using SQLMesh format_model_expressions."""
129+ logger .info (f"Formatting document: { params .text_document .uri } " )
99130 document = await ensure_context_for_document (ls .workspace .get_document (params .text_document .uri ))
100131 context , model = PATHS_TO_MODELS .get (document .path , (None , None ))
101132 if context is None or model is None :
@@ -104,7 +135,8 @@ async def formatting(ls: LanguageServer, params: types.DocumentFormattingParams)
104135 dialect = model .dialect if model and model .is_sql else default_dialect
105136 try :
106137 expressions = parse (document .source , default_dialect = dialect )
107- except Exception :
138+ except Exception as e :
139+ logger .error (f"Exception occurred while parsing document: { e } " )
108140 return []
109141 try :
110142 fmt_doc = format_model_expressions (expressions , dialect , ** context .config .format .generator_options )
@@ -152,6 +184,66 @@ def _iter_match_ranges_in_projection(term: str, source: str):
152184 )
153185
154186
def _update_diagnostics(document: TextDocument) -> None:
    """Update diagnostics for the given document.

    The entry for ``document.uri`` in ``WORKSPACE_DIAGNOSTICS`` is reset to
    ``(document.version, [])`` and the list is then filled in place with:

    * one Error per entry of a sqlglot ``ParseError`` raised while parsing
      the document source,
    * one Warning anchored at ``_top_of_file`` holding whatever the
      "sqlmesh.core.renderer" logger emitted at WARNING level or above while
      rendering the query (``SqlModel`` instances only),
    * one Warning per range yielded by ``_iter_match_ranges_in_projection``
      for every column whose type is not known.

    If no (context, model) pair is registered for ``document.path`` the
    entry is left with an empty list.
    """
    diagnostics: t.List[types.Diagnostic] = []
    WORKSPACE_DIAGNOSTICS[document.uri] = (document.version, diagnostics)
    context, model = PATHS_TO_MODELS.get(document.path, (None, None))
    if context is None or model is None:
        return

    default_dialect = context.default_dialect
    dialect = model.dialect if model and model.is_sql else default_dialect

    try:
        parse(document.source, default_dialect=dialect)
    except ParseError as e:
        for error in e.errors:
            # sqlglot may leave line/col unset; fall back to 0 so the
            # arithmetic and the Position construction below stay valid.
            line = error["line"] or 0
            col = error["col"] or 0
            # HACK: shift the reported line up by the number of comment lines
            # that precede it. Not a proper solution, but it works.
            comments_before_line = sum(
                1 for _l in document.lines[:line] if _l.strip().startswith(("/*", "--"))
            )
            # Clamp: the adjustment above must never produce a negative line.
            line = max(line - comments_before_line, 0)
            position = types.Position(line=line, character=col)
            diagnostics.append(
                types.Diagnostic(
                    message=e.args[0],
                    severity=types.DiagnosticSeverity.Error,
                    range=types.Range(start=position, end=position),
                )
            )

    if isinstance(model, SqlModel):
        # Temporarily attach a handler to the renderer's logger so that any
        # warnings emitted during rendering can be surfaced as a diagnostic.
        sqlmesh_renderer_logger = logging.getLogger("sqlmesh.core.renderer")
        buf = io.StringIO()
        interceptor = logging.StreamHandler(stream=buf)
        interceptor.setLevel(logging.WARNING)
        interceptor.setFormatter(logging.Formatter("%(message)s"))
        sqlmesh_renderer_logger.addHandler(interceptor)
        try:
            model._query_renderer.render(execution_time="now")
        finally:
            # Always detach, otherwise a failing render would leave one more
            # handler on the shared logger for every call.
            sqlmesh_renderer_logger.removeHandler(interceptor)
        warnings = buf.getvalue().strip()
        if warnings:
            diagnostics.append(
                types.Diagnostic(message=warnings, severity=types.DiagnosticSeverity.Warning, range=_top_of_file)
            )

    setattr(model, "_columns_to_types", None)  # clear cached columns to types
    if model and model.columns_to_types:
        for column, type_ in model.columns_to_types.items():
            if type_is_known(type_):
                continue
            for range_ in _iter_match_ranges_in_projection(column, document.source):
                diagnostics.append(
                    types.Diagnostic(
                        message=f"Unknown type for column: {column} - add a type hint to final projection",
                        severity=types.DiagnosticSeverity.Warning,
                        range=range_,
                    )
                )
245+
246+
155247@server .feature (types .TEXT_DOCUMENT_DID_OPEN )
156248async def did_open (ls : LanguageServer , params : types .DidOpenTextDocumentParams ):
157249 """Update diagnostics on document open and refresh context if necessary."""
@@ -166,6 +258,16 @@ async def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams):
166258 PATHS_TO_MODELS .update (
167259 {str (model ._path .resolve ()): (context , weakref .proxy (model )) for model in context .models .values ()}
168260 )
261+ _update_diagnostics (document )
262+ for uri , (version , diagnostics ) in WORKSPACE_DIAGNOSTICS .items ():
263+ ls .publish_diagnostics (uri = uri , version = version , diagnostics = diagnostics )
264+
265+
@server.feature(types.TEXT_DOCUMENT_DID_CLOSE)
async def did_close(ls: LanguageServer, params: types.DidCloseTextDocumentParams):
    """Forget the stored diagnostics of a document once it is closed."""
    closed_uri = params.text_document.uri
    # A URI with no stored diagnostics is simply ignored.
    WORKSPACE_DIAGNOSTICS.pop(closed_uri, None)
169271
170272
171273@server .feature (types .TEXT_DOCUMENT_DID_SAVE )
@@ -181,6 +283,18 @@ async def did_save(ls: LanguageServer, params: types.DidOpenTextDocumentParams):
181283 if model ._path == Path (document .path ):
182284 PATHS_TO_MODELS [document .path ] = (context , weakref .proxy (model ))
183285 break
286+ _update_diagnostics (document )
287+ for uri , (version , diagnostics ) in WORKSPACE_DIAGNOSTICS .items ():
288+ ls .publish_diagnostics (uri = uri , version = version , diagnostics = diagnostics )
289+
290+
@server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
async def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) -> None:
    """Update diagnostics on document change.

    Recomputes the diagnostics of the changed document and then publishes
    every entry currently held in ``WORKSPACE_DIAGNOSTICS``, not only the
    one for the changed document.
    """
    document = await ensure_context_for_document(ls.workspace.get_text_document(params.text_document.uri))
    _update_diagnostics(document)
    for uri, (version, diagnostics) in WORKSPACE_DIAGNOSTICS.items():
        ls.publish_diagnostics(uri=uri, version=version, diagnostics=diagnostics)
184298
185299
186300@server .feature (types .WORKSPACE_DID_CHANGE_WATCHED_FILES )
@@ -189,6 +303,9 @@ async def did_change_watched_files(ls: LanguageServer, params: types.DidChangeWa
189303 updated = {}
190304 for change in params .changes :
191305 document = await ensure_context_for_document (ls .workspace .get_text_document (change .uri ))
306+ if change .type == types .FileChangeType .Changed :
307+ _update_diagnostics (document )
308+ continue
192309 path = Path (document .path )
193310 known_paths = PATHS_TO_MODELS .keys ()
194311 if change .type == types .FileChangeType .Deleted and str (path ) in known_paths :
@@ -215,8 +332,11 @@ async def did_change_watched_files(ls: LanguageServer, params: types.DidChangeWa
215332 )
216333 updated [context .path ] = True
217334
335+
def main():
    """Set up DEBUG-level logging, then start the server via ``server.start_io()``."""
    log_level = logging.DEBUG
    logging.basicConfig(level=log_level)
    server.start_io()


# Run the server only when this file is executed as a script.
if __name__ == "__main__":
    main()
0 commit comments