44import asyncio
55import gc
66import logging
7- import re
87import typing as t
98import weakref
109from collections import defaultdict
1110from contextlib import suppress
12- from functools import lru_cache
1311from itertools import cycle
1412from pathlib import Path
1513
1614import sqlmesh
1715from lsprotocol import types
1816from pygls .server import LanguageServer
1917from pygls .workspace import TextDocument
20- from sqlmesh .core . dialect import format_model_expressions , parse
18+ from sqlmesh ._version import __version__
2119
2220logger = logging .getLogger (__name__ )
2321
24- WORKSPACE_DIAGNOSTICS : t .Dict [str , t .Tuple [t .Optional [int ], t .List [types .Diagnostic ]]] = {}
25- """A mapping of document URIs to diagnostics."""
26-
2722CONTEXTS : t .Dict [str , sqlmesh .Context ] = {}
2823"""A mapping of workspace paths to SQLMesh contexts."""
2924
3025PATHS_TO_MODELS : t .Dict [str , t .Tuple [sqlmesh .Context , sqlmesh .Model ]] = {}
3126"""A mapping of file paths to SQLMesh (context, model) tuples."""
3227
33- C_MUTEX = defaultdict (asyncio .Lock )
28+ C_MUTEX : t . DefaultDict [ t . Union [ str , Path ], asyncio . Lock ] = defaultdict (asyncio .Lock )
3429"""A locking mechanism for ensuring that context mutation is thread-safe."""
3530
3631loop = asyncio .get_event_loop ()
3732
38- server = LanguageServer ("sqlmesh-lsp " , "v0.1.0" , loop = loop )
33+ server = LanguageServer ("sqlmesh_lsp " , __version__ , loop = loop )
3934
4035
4136async def refresh_context_loop (context : sqlmesh .Context ) -> None :
@@ -46,11 +41,15 @@ async def refresh_context_loop(context: sqlmesh.Context) -> None:
4641 gc_iter = cycle (list (range (10 )))
4742 while True :
4843 await asyncio .sleep (10.0 )
49- if context ._loader .reload_needed ():
50- async with C_MUTEX [context .path ]:
51- await asyncio .to_thread (context .load )
44+ for loader in context ._loaders :
45+ if loader .reload_needed ():
46+ async with C_MUTEX [context .path ]:
47+ await asyncio .to_thread (context .load )
5248 PATHS_TO_MODELS .update (
53- {str (model ._path .resolve ()): (context , weakref .proxy (model )) for model in context .models .values ()}
49+ {
50+ str (model ._path .resolve ()): (context , weakref .proxy (model ))
51+ for model in context .models .values ()
52+ }
5453 )
5554 if next (gc_iter ) == 0 :
5655 gc .collect ()
@@ -84,7 +83,10 @@ async def ensure_context_for_document(document: TextDocument) -> TextDocument:
8483 loop .create_task (refresh_context_loop (handle ))
8584 CONTEXTS [str (path )] = handle
8685 PATHS_TO_MODELS .update (
87- {str (model ._path .resolve ()): (handle , weakref .proxy (model )) for model in handle .models .values ()}
86+ {
87+ str (model ._path .resolve ()): (handle , weakref .proxy (model ))
88+ for model in handle .models .values ()
89+ }
8890 )
8991 server .show_message (f"Context loaded for: { path } " )
9092 loaded = True
@@ -94,99 +96,81 @@ async def ensure_context_for_document(document: TextDocument) -> TextDocument:
9496
9597
9698@server .feature (types .TEXT_DOCUMENT_FORMATTING )
97- async def formatting (ls : LanguageServer , params : types .DocumentFormattingParams ):
99+ async def formatting (
100+ ls : LanguageServer , params : types .DocumentFormattingParams
101+ ) -> t .List [types .TextEdit ]:
98102 """Format the document based using SQLMesh format_model_expressions."""
99- document = await ensure_context_for_document (ls .workspace .get_document (params .text_document .uri ))
100- context , model = PATHS_TO_MODELS .get (document .path , (None , None ))
101- if context is None or model is None :
102- return []
103- default_dialect = context .default_dialect
104- dialect = model .dialect if model and model .is_sql else default_dialect
105- try :
106- expressions = parse (document .source , default_dialect = dialect )
107- except Exception :
108- return []
109103 try :
110- fmt_doc = format_model_expressions (expressions , dialect , ** context .config .format .generator_options )
111- if context .config .format .append_newline :
112- fmt_doc += "\n "
104+ logger .info (f"Formatting document: { params .text_document .uri } " )
105+ document = await ensure_context_for_document (
106+ ls .workspace .get_document (params .text_document .uri )
107+ )
108+ context , model = PATHS_TO_MODELS .get (document .path , (None , None ))
109+ context .format (paths = [Path (document .path )])
110+ with open (document .path , "r+" , encoding = "utf-8" ) as file :
111+ return [
112+ types .TextEdit (
113+ range = types .Range (
114+ types .Position (0 , 0 ),
115+ types .Position (len (document .lines ), len (document .lines [- 1 ])),
116+ ),
117+ new_text = file .read (),
118+ )
119+ ]
113120 except Exception as e :
114121 ls .show_message (f"Error formatting SQL: { e } " , types .MessageType .Error )
115122 return []
116- return [
117- types .TextEdit (
118- range = types .Range (
119- types .Position (0 , 0 ),
120- types .Position (len (document .lines ), len (document .lines [- 1 ])),
121- ),
122- new_text = fmt_doc ,
123- )
124- ]
125-
126-
127- _top_of_file = types .Range (start = types .Position (line = 0 , character = 0 ), end = types .Position (line = 0 , character = 0 ))
128-
129- _cached_re_compile = t .cast (t .Callable [[str , re .RegexFlag ], re .Pattern [str ]], lru_cache (maxsize = 1024 )(re .compile ))
130-
131-
132- def _iter_match_ranges_in_projection (term : str , source : str ):
133- """Iterate over ranges of matches for a term in a SQL projection."""
134- col_patt = _cached_re_compile (rf'\b["`]?({ term } )["`]?,?' , re .IGNORECASE )
135- projection_patt = _cached_re_compile (r"SELECT\s+(.*)\s+FROM" , re .DOTALL | re .IGNORECASE )
136- for p_match in projection_patt .finditer (source ):
137- if not p_match .group (1 ):
138- continue
139- proj_start , proj_end = p_match .span (1 )
140- proj_substr = source [proj_start :proj_end ]
141- for c_match in col_patt .finditer (proj_substr ):
142- if not c_match .group (1 ):
143- continue
144- col_start , col_end = c_match .span (1 )
145- start , end = proj_start + col_start , proj_start + col_end
146- line = source .count ("\n " , 0 , start )
147- char_s = start - source .rfind ("\n " , 0 , start ) - 1
148- char_e = end - source .rfind ("\n " , 0 , end )
149- yield types .Range (
150- start = types .Position (line = line , character = char_s ),
151- end = types .Position (line = line , character = char_e ),
152- )
153123
154124
155125@server .feature (types .TEXT_DOCUMENT_DID_OPEN )
156- async def did_open (ls : LanguageServer , params : types .DidOpenTextDocumentParams ):
126+ async def did_open (ls : LanguageServer , params : types .DidOpenTextDocumentParams ) -> None :
157127 """Update diagnostics on document open and refresh context if necessary."""
158- document = await ensure_context_for_document (ls .workspace .get_document (params .text_document .uri ))
128+ document = await ensure_context_for_document (
129+ ls .workspace .get_document (params .text_document .uri )
130+ )
159131 path = Path (document .path )
160132 known_paths = PATHS_TO_MODELS .keys ()
161133 for context in CONTEXTS .values ():
162- if path .is_relative_to (context .path ) and path .suffix in (".sql" , ".py" ) and str (path ) not in known_paths :
134+ if (
135+ path .is_relative_to (context .path )
136+ and path .suffix in (".sql" , ".py" )
137+ and str (path ) not in known_paths
138+ ):
163139 ls .show_message (f"Refreshing context with new file: { path } " , types .MessageType .Info )
164140 async with C_MUTEX [context .path ]:
165141 await asyncio .to_thread (context .load )
166142 PATHS_TO_MODELS .update (
167- {str (model ._path .resolve ()): (context , weakref .proxy (model )) for model in context .models .values ()}
143+ {
144+ str (model ._path .resolve ()): (context , weakref .proxy (model ))
145+ for model in context .models .values ()
146+ }
168147 )
169148
170149
171150@server .feature (types .TEXT_DOCUMENT_DID_SAVE )
172- async def did_save (ls : LanguageServer , params : types .DidOpenTextDocumentParams ):
151+ async def did_save (ls : LanguageServer , params : types .DidOpenTextDocumentParams ) -> None :
173152 """Update diagnostics on document save."""
174- document = await ensure_context_for_document (ls .workspace .get_document (params .text_document .uri ))
153+ document = await ensure_context_for_document (
154+ ls .workspace .get_document (params .text_document .uri )
155+ )
175156 context , _ = PATHS_TO_MODELS .get (document .path , (None , None ))
176157 if context is not None :
177- context ._loader ._path_mtimes [Path (document .path )] = 0.0
178- async with C_MUTEX [context .path ]:
179- await asyncio .to_thread (context .load )
180- for model in context .models .values ():
181- if model ._path == Path (document .path ):
182- PATHS_TO_MODELS [document .path ] = (context , weakref .proxy (model ))
183- break
158+ for loader in context ._loaders :
159+ loader ._path_mtimes [Path (document .path )] = 0.0
160+ async with C_MUTEX [context .path ]:
161+ await asyncio .to_thread (context .load )
162+ for model in context .models .values ():
163+ if model ._path == Path (document .path ):
164+ PATHS_TO_MODELS [document .path ] = (context , weakref .proxy (model ))
165+ break
184166
185167
186168@server .feature (types .WORKSPACE_DID_CHANGE_WATCHED_FILES )
187- async def did_change_watched_files (ls : LanguageServer , params : types .DidChangeWatchedFilesParams ):
169+ async def did_change_watched_files (
170+ ls : LanguageServer , params : types .DidChangeWatchedFilesParams
171+ ) -> None :
188172 """Refresh context if a file changes."""
189- updated = {}
173+ updated : t . Dict [ t . Union [ str , Path ], bool ] = {}
190174 for change in params .changes :
191175 document = await ensure_context_for_document (ls .workspace .get_text_document (change .uri ))
192176 path = Path (document .path )
@@ -215,8 +199,11 @@ async def did_change_watched_files(ls: LanguageServer, params: types.DidChangeWa
215199 )
216200 updated [context .path ] = True
217201
218- def main ():
202+
203+ def main () -> None :
204+ logging .basicConfig (level = logging .DEBUG )
219205 server .start_io ()
220206
207+
221208if __name__ == "__main__" :
222209 main ()
0 commit comments