Skip to content

Commit 5f25f0a

Browse files
committed
temp
1 parent 8bf7e8a commit 5f25f0a

1 file changed

Lines changed: 107 additions & 69 deletions

File tree

sqlmesh/lsp/main.py

Lines changed: 107 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,96 +1,134 @@
11
#!/usr/bin/env python
2-
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration."""
2+
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals."""
33

44
import itertools
55
import logging
66
import typing as t
77
from contextlib import suppress
88
from pathlib import Path
99

10-
from sqlmesh.core.audit.definition import ModelAudit
11-
from sqlmesh.core.context import Context
1210
from lsprotocol import types
1311
from pygls.server import LanguageServer
1412
from pygls.workspace import TextDocument
13+
1514
from sqlmesh._version import __version__
15+
from sqlmesh.core.audit.definition import ModelAudit
16+
from sqlmesh.core.context import Context
1617
from sqlmesh.core.model import Model
1718

18-
logger = logging.getLogger(__name__)
19-
20-
GLOBAL_CONTEXT: t.Optional[Context] = None
21-
FILE_MAP: t.Dict[str, t.Union[Model, ModelAudit]] = {}
2219

20+
class SQLMeshLanguageServer:
21+
def __init__(
22+
self,
23+
context_class: t.Type[Context],
24+
server_name: str = "sqlmesh_lsp",
25+
version: str = __version__,
26+
):
27+
"""
28+
:param context_class: A class that inherits from `Context`.
29+
:param server_name: Name for the language server.
30+
:param version: Version string.
31+
"""
32+
self.server = LanguageServer(server_name, version)
33+
self.context_class = context_class
34+
self.context: t.Optional[Context] = None
35+
self.file_map: t.Dict[str, t.Union[Model, ModelAudit]] = {}
36+
37+
# Register LSP features (e.g., formatting, hover, etc.)
38+
self._register_features()
39+
40+
def _register_features(self) -> None:
41+
"""Register LSP features on the internal LanguageServer instance."""
42+
43+
@self.server.feature(types.TEXT_DOCUMENT_FORMATTING)
44+
def formatting(
45+
ls: LanguageServer, params: types.DocumentFormattingParams
46+
) -> t.List[types.TextEdit]:
47+
"""Format the document using SQLMesh `format_model_expressions`."""
48+
try:
49+
document = self.ensure_context_for_document(
50+
ls.workspace.get_document(params.text_document.uri)
51+
)
2352

24-
server = LanguageServer("sqlmesh_lsp", __version__)
53+
if self.context is None:
54+
raise RuntimeError(f"No context found for document: {document.path}")
55+
56+
# Perform formatting using the loaded context
57+
self.context.format(paths=(Path(document.path),))
58+
with open(document.path, "r+", encoding="utf-8") as file:
59+
new_text = file.read()
60+
61+
# Return a single edit that replaces the entire file.
62+
return [
63+
types.TextEdit(
64+
range=types.Range(
65+
start=types.Position(line=0, character=0),
66+
end=types.Position(
67+
line=len(document.lines),
68+
character=len(document.lines[-1]) if document.lines else 0,
69+
),
70+
),
71+
new_text=new_text,
72+
)
73+
]
74+
except Exception as e:
75+
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
76+
return []
77+
78+
def ensure_context_for_document(self, document: TextDocument) -> TextDocument:
79+
"""
80+
Ensure that a context exists for the given document if applicable by searching
81+
for a config.py or config.yml file in the parent directories.
82+
"""
83+
# If the context is already loaded, check if this document belongs to it.
84+
if self.context is not None:
85+
self.context.load() # Reload or refresh context
86+
if document.uri in self.file_map:
87+
return document
2588

89+
# Try to match the document path with existing models/audits
90+
for model in itertools.chain(
91+
self.context._models.values(), self.context._audits.values()
92+
):
93+
if model._path is None:
94+
continue
95+
if model._path.resolve() == Path(document.path):
96+
self.file_map[document.uri] = model
97+
return document
98+
return document
2699

27-
def ensure_context_for_document(document: TextDocument) -> TextDocument:
28-
"""Ensure that a context exists for the given document if applicable by searching for a config.py or config.yml file in the parent directories."""
29-
# If the context is already loaded, return the document, if it is part of the same context
30-
global GLOBAL_CONTEXT, FILE_MAP
31-
if GLOBAL_CONTEXT is not None:
32-
GLOBAL_CONTEXT.load()
33-
if document.uri in FILE_MAP:
100+
# No context yet: try to find config and load it
101+
path = Path(document.path).resolve()
102+
if path.suffix not in (".sql", ".py"):
34103
return document
35-
for model in itertools.chain(
36-
GLOBAL_CONTEXT._models.values(), GLOBAL_CONTEXT._audits.values()
37-
):
38-
if model._path is None:
39-
continue
40-
path = model._path.resolve()
41-
if path == document.path:
42-
FILE_MAP[document.uri] = model
43-
return document
44-
return document
45104

46-
# If there is no context, load the context and then call this function again
47-
path = Path(document.path).resolve()
48-
if path.suffix not in (".sql", ".py"):
105+
loaded = False
106+
# Ascend directories to look for config
107+
while path.parents and not loaded:
108+
for ext in ("py", "yml", "yaml"):
109+
config_path = path / f"config.{ext}"
110+
if config_path.exists():
111+
with suppress(Exception):
112+
# Use user-provided instantiator to build the context
113+
self.context = self.context_class(paths=[path])
114+
self.server.show_message(f"Context loaded for: {path}")
115+
loaded = True
116+
# Re-check context for document now that it's loaded
117+
return self.ensure_context_for_document(document)
118+
path = path.parent
119+
49120
return document
50-
loaded = False
51-
while path.parents and not loaded:
52-
for ext in ("py", "yml", "yaml"):
53-
config_path = path / f"config.{ext}"
54-
if config_path.exists():
55-
with suppress(Exception):
56-
GLOBAL_CONTEXT = Context(paths=[path])
57-
server.show_message(f"Context loaded for: {path}")
58-
loaded = True
59-
return ensure_context_for_document(document)
60-
path = path.parent
61-
62-
return document
63-
64-
65-
@server.feature(types.TEXT_DOCUMENT_FORMATTING)
66-
def formatting(
67-
ls: LanguageServer, params: types.DocumentFormattingParams
68-
) -> t.List[types.TextEdit]:
69-
"""Format the document using SQLMesh format_model_expressions."""
70-
try:
71-
document = ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
72-
context = GLOBAL_CONTEXT
73-
if context is None:
74-
raise Exception(f"No context found for document: {document.path}")
75-
context.format(paths=(Path(document.path),))
76-
with open(document.path, "r+", encoding="utf-8") as file:
77-
return [
78-
types.TextEdit(
79-
range=types.Range(
80-
types.Position(0, 0),
81-
types.Position(len(document.lines), len(document.lines[-1])),
82-
),
83-
new_text=file.read(),
84-
)
85-
]
86-
except Exception as e:
87-
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
88-
return []
121+
122+
def start(self) -> None:
123+
"""Start the server with I/O transport."""
124+
logging.basicConfig(level=logging.DEBUG)
125+
self.server.start_io()
89126

90127

91128
def main() -> None:
92-
logging.basicConfig(level=logging.DEBUG)
93-
server.start_io()
129+
# Example instantiator that just uses the same signature as your original `Context` usage.
130+
sqlmesh_server = SQLMeshLanguageServer(context_class=Context)
131+
sqlmesh_server.start()
94132

95133

96134
if __name__ == "__main__":

0 commit comments

Comments
 (0)