Skip to content

Commit 675d9a7

Browse files
committed
feat: introduce lsp
- supports formatting whole documents
1 parent 5d33825 commit 675d9a7

3 files changed

Lines changed: 103 additions & 1 deletion

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
.PHONY: docs
22

33
install-dev:
4-
pip3 install -e ".[dev,web,slack,dlt]" ./examples/custom_materializations
4+
pip3 install -e ".[dev,web,slack,dlt,lsp]" ./examples/custom_materializations
55

66
install-doc:
77
pip3 install -r ./docs/requirements.txt

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,16 @@ web = [
119119
"sse-starlette>=0.2.2",
120120
"pyarrow",
121121
]
122+
lsp = [
123+
"pygls",
124+
"lsprotocol"
125+
]
122126
risingwave = ["psycopg2"]
123127

124128
[project.scripts]
125129
sqlmesh = "sqlmesh.cli.main:cli"
126130
sqlmesh_cicd = "sqlmesh.cicd.bot:bot"
131+
sqlmesh_lsp = "sqlmesh.lsp.main:main"
127132

128133
[project.entry-points."airflow.plugins"]
129134
sqlmesh_airflow = "sqlmesh.schedulers.airflow.plugin:SqlmeshAirflowPlugin"

sqlmesh/lsp/main.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#!/usr/bin/env python
2+
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration."""
3+
4+
import logging
5+
import typing as t
6+
from contextlib import suppress
7+
from pathlib import Path
8+
9+
from sqlmesh.core.context import Context
10+
from lsprotocol import types
11+
from pygls.server import LanguageServer
12+
from pygls.workspace import TextDocument
13+
from sqlmesh._version import __version__
14+
from sqlmesh.core.model import Model
15+
16+
logger = logging.getLogger(__name__)
17+
18+
CONTEXTS: t.Dict[str, Context] = {}
19+
"""A mapping of workspace paths to SQLMesh contexts."""
20+
21+
PATHS_TO_MODELS: t.Dict[str, t.Tuple[Context, Model]] = {}
22+
"""A mapping of file paths to SQLMesh (context, model) tuples."""
23+
24+
server = LanguageServer("sqlmesh_lsp", __version__)
25+
26+
_CACHE: t.Set[str] = set()
27+
"""A cache of URIs for which we have already ensured a context exists."""
28+
29+
30+
def ensure_context_for_document(document: TextDocument) -> TextDocument:
31+
"""Ensure that a context exists for the given document if applicable by searching for a config.py or config.yml file in the parent directories."""
32+
if document.uri in _CACHE:
33+
return document
34+
_CACHE.add(document.uri)
35+
path = Path(document.path).resolve()
36+
if path.suffix not in (".sql", ".py"):
37+
return document
38+
initial_path = path
39+
while path.parents:
40+
if str(path) in CONTEXTS:
41+
return document
42+
path = path.parent
43+
path = initial_path
44+
loaded = False
45+
while path.parents and not loaded:
46+
for ext in ("py", "yml", "yaml"):
47+
config_path = path / f"config.{ext}"
48+
if config_path.exists():
49+
with suppress(Exception):
50+
handle = Context(paths=[f"{path}"])
51+
CONTEXTS[str(path)] = handle
52+
PATHS_TO_MODELS.update(
53+
{
54+
str(model._path.resolve()): (handle, model)
55+
for model in handle.models.values()
56+
}
57+
)
58+
server.show_message(f"Context loaded for: {path}")
59+
loaded = True
60+
break
61+
path = path.parent
62+
return document
63+
64+
65+
@server.feature(types.TEXT_DOCUMENT_FORMATTING)
66+
def formatting(
67+
ls: LanguageServer, params: types.DocumentFormattingParams
68+
) -> t.List[types.TextEdit]:
69+
"""Format the document based using SQLMesh format_model_expressions."""
70+
try:
71+
document = ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
72+
context, _ = PATHS_TO_MODELS.get(document.path, (None, None))
73+
if context is None:
74+
raise Exception(f"No context found for document: {document.path}")
75+
context.format(paths=(Path(document.path),))
76+
with open(document.path, "r+", encoding="utf-8") as file:
77+
return [
78+
types.TextEdit(
79+
range=types.Range(
80+
types.Position(0, 0),
81+
types.Position(len(document.lines), len(document.lines[-1])),
82+
),
83+
new_text=file.read(),
84+
)
85+
]
86+
except Exception as e:
87+
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
88+
return []
89+
90+
91+
def main() -> None:
92+
logging.basicConfig(level=logging.DEBUG)
93+
server.start_io()
94+
95+
96+
if __name__ == "__main__":
97+
main()

0 commit comments

Comments
 (0)