11#!/usr/bin/env python
22"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration."""
33
4- import asyncio
5- import gc
64import logging
75import typing as t
86import weakref
9- from collections import defaultdict
107from contextlib import suppress
11- from itertools import cycle
128from pathlib import Path
139
14- import sqlmesh
10+ from sqlmesh . core . context import Context
1511from lsprotocol import types
1612from pygls .server import LanguageServer
1713from pygls .workspace import TextDocument
1814from sqlmesh ._version import __version__
15+ from sqlmesh .core .model import Model
1916
2017logger = logging .getLogger (__name__ )
2118
22- CONTEXTS : t .Dict [str , sqlmesh . Context ] = {}
19+ CONTEXTS : t .Dict [str , Context ] = {}
2320"""A mapping of workspace paths to SQLMesh contexts."""
2421
25- PATHS_TO_MODELS : t .Dict [str , t .Tuple [sqlmesh . Context , sqlmesh . Model ]] = {}
22+ PATHS_TO_MODELS : t .Dict [str , t .Tuple [Context , Model ]] = {}
2623"""A mapping of file paths to SQLMesh (context, model) tuples."""
2724
28- C_MUTEX : t .DefaultDict [t .Union [str , Path ], asyncio .Lock ] = defaultdict (asyncio .Lock )
29- """A locking mechanism for ensuring that context mutation is thread-safe."""
25+ server = LanguageServer ("sqlmesh_lsp" , __version__ )
3026
31- loop = asyncio .get_event_loop ()
32-
33- server = LanguageServer ("sqlmesh_lsp" , __version__ , loop = loop )
34-
35-
36- async def refresh_context_loop (context : sqlmesh .Context ) -> None :
37- """Refresh the SQLMesh context every 5 seconds.
38- SQLMesh already ensures that the context is only refreshed when necessary so this
39- is efficient and safe to do, even if the context is large. Mtime checks are used.
40- """
41- gc_iter = cycle (list (range (10 )))
42- while True :
43- await asyncio .sleep (10.0 )
44- for loader in context ._loaders :
45- if loader .reload_needed ():
46- async with C_MUTEX [context .path ]:
47- await asyncio .to_thread (context .load )
48- PATHS_TO_MODELS .update (
49- {
50- str (model ._path .resolve ()): (context , weakref .proxy (model ))
51- for model in context .models .values ()
52- }
53- )
54- if next (gc_iter ) == 0 :
55- gc .collect ()
56-
57-
58- _CACHE = set ()
27+ _CACHE : t .Set [str ] = set ()
5928"""A cache of URIs for which we have already ensured a context exists."""
6029
6130
62- async def ensure_context_for_document (document : TextDocument ) -> TextDocument :
31+ def ensure_context_for_document (document : TextDocument ) -> TextDocument :
6332 """Ensure that a context exists for the given document if applicable."""
6433 if document .uri in _CACHE :
6534 return document
@@ -79,8 +48,7 @@ async def ensure_context_for_document(document: TextDocument) -> TextDocument:
7948 config_path = path / f"config.{ ext } "
8049 if config_path .exists ():
8150 with suppress (Exception ):
82- handle = sqlmesh .Context (paths = path )
83- loop .create_task (refresh_context_loop (handle ))
51+ handle = Context (paths = [f"{ path } " ])
8452 CONTEXTS [str (path )] = handle
8553 PATHS_TO_MODELS .update (
8654 {
@@ -96,17 +64,18 @@ async def ensure_context_for_document(document: TextDocument) -> TextDocument:
9664
9765
9866@server .feature (types .TEXT_DOCUMENT_FORMATTING )
99- async def formatting (
67+ def formatting (
10068 ls : LanguageServer , params : types .DocumentFormattingParams
10169) -> t .List [types .TextEdit ]:
10270 """Format the document based using SQLMesh format_model_expressions."""
10371 try :
10472 logger .info (f"Formatting document: { params .text_document .uri } " )
105- document = await ensure_context_for_document (
106- ls .workspace .get_document (params .text_document .uri )
107- )
108- context , model = PATHS_TO_MODELS .get (document .path , (None , None ))
109- context .format (paths = [Path (document .path )])
73+ document = ensure_context_for_document (ls .workspace .get_document (params .text_document .uri ))
74+ context , _ = PATHS_TO_MODELS .get (document .path , (None , None ))
75+ if context is None :
76+ logger .error (f"No context found for document: { document .path } " )
77+ return []
78+ context .format (paths = (Path (document .path ),))
11079 with open (document .path , "r+" , encoding = "utf-8" ) as file :
11180 return [
11281 types .TextEdit (
@@ -122,84 +91,6 @@ async def formatting(
12291 return []
12392
12493
125- @server .feature (types .TEXT_DOCUMENT_DID_OPEN )
126- async def did_open (ls : LanguageServer , params : types .DidOpenTextDocumentParams ) -> None :
127- """Update diagnostics on document open and refresh context if necessary."""
128- document = await ensure_context_for_document (
129- ls .workspace .get_document (params .text_document .uri )
130- )
131- path = Path (document .path )
132- known_paths = PATHS_TO_MODELS .keys ()
133- for context in CONTEXTS .values ():
134- if (
135- path .is_relative_to (context .path )
136- and path .suffix in (".sql" , ".py" )
137- and str (path ) not in known_paths
138- ):
139- ls .show_message (f"Refreshing context with new file: { path } " , types .MessageType .Info )
140- async with C_MUTEX [context .path ]:
141- await asyncio .to_thread (context .load )
142- PATHS_TO_MODELS .update (
143- {
144- str (model ._path .resolve ()): (context , weakref .proxy (model ))
145- for model in context .models .values ()
146- }
147- )
148-
149-
150- @server .feature (types .TEXT_DOCUMENT_DID_SAVE )
151- async def did_save (ls : LanguageServer , params : types .DidOpenTextDocumentParams ) -> None :
152- """Update diagnostics on document save."""
153- document = await ensure_context_for_document (
154- ls .workspace .get_document (params .text_document .uri )
155- )
156- context , _ = PATHS_TO_MODELS .get (document .path , (None , None ))
157- if context is not None :
158- for loader in context ._loaders :
159- loader ._path_mtimes [Path (document .path )] = 0.0
160- async with C_MUTEX [context .path ]:
161- await asyncio .to_thread (context .load )
162- for model in context .models .values ():
163- if model ._path == Path (document .path ):
164- PATHS_TO_MODELS [document .path ] = (context , weakref .proxy (model ))
165- break
166-
167-
168- @server .feature (types .WORKSPACE_DID_CHANGE_WATCHED_FILES )
169- async def did_change_watched_files (
170- ls : LanguageServer , params : types .DidChangeWatchedFilesParams
171- ) -> None :
172- """Refresh context if a file changes."""
173- updated : t .Dict [t .Union [str , Path ], bool ] = {}
174- for change in params .changes :
175- document = await ensure_context_for_document (ls .workspace .get_text_document (change .uri ))
176- path = Path (document .path )
177- known_paths = PATHS_TO_MODELS .keys ()
178- if change .type == types .FileChangeType .Deleted and str (path ) in known_paths :
179- # We don't need to refresh the context if a file is deleted, we just remove it from the cache
180- del PATHS_TO_MODELS [str (path )]
181- continue
182- for context in CONTEXTS .values ():
183- # If a new file is created, we need to force reload the appropriate context
184- if (
185- path .is_relative_to (context .path )
186- and path .suffix in (".sql" , ".py" )
187- and str (path ) not in known_paths
188- and change .type == types .FileChangeType .Created
189- and not updated .get (context .path , False )
190- ):
191- ls .show_message (f"Refreshing context with new file: { path } " , types .MessageType .Info )
192- async with C_MUTEX [context .path ]:
193- await asyncio .to_thread (context .load )
194- PATHS_TO_MODELS .update (
195- {
196- str (model ._path .resolve ()): (context , weakref .proxy (model ))
197- for model in context .models .values ()
198- }
199- )
200- updated [context .path ] = True
201-
202-
20394def main () -> None :
20495 logging .basicConfig (level = logging .DEBUG )
20596 server .start_io ()
0 commit comments