Skip to content

Commit 54c2ef4

Browse files
committed
feat: introduce lsp
- supports formatting whole documents
1 parent cc0d42b commit 54c2ef4

3 files changed

Lines changed: 109 additions & 2 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
.PHONY: docs
22

33
install-dev:
4-
pip3 install -e ".[dev,web,slack,dlt]" ./examples/custom_materializations
4+
pip3 install -e ".[dev,web,slack,dlt,lsp]" ./examples/custom_materializations
55

66
install-doc:
77
pip3 install -r ./docs/requirements.txt

pyproject.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,19 @@ web = [
118118
"uvicorn[standard]==0.22.0",
119119
"sse-starlette>=0.2.2",
120120
"pyarrow",
121+
"lsprotocol",
122+
"pygls",
123+
]
124+
lsp = [
125+
"pygls",
126+
"lsprotocol"
121127
]
122128
risingwave = ["psycopg2"]
123129

124130
[project.scripts]
125131
sqlmesh = "sqlmesh.cli.main:cli"
126132
sqlmesh_cicd = "sqlmesh.cicd.bot:bot"
133+
sqlmesh_lsp = "sqlmesh.lsp.main:main"
127134

128135
[project.entry-points."airflow.plugins"]
129136
sqlmesh_airflow = "sqlmesh.schedulers.airflow.plugin:SqlmeshAirflowPlugin"
@@ -147,7 +154,7 @@ fallback_version = "0.0.0"
147154
local_scheme = "no-local-version"
148155

149156
[tool.setuptools.packages.find]
150-
include = ["sqlmesh", "sqlmesh.*", "web*"]
157+
include = ["sqlmesh", "sqlmesh.*", "web*"]
151158

152159
[tool.setuptools.package-data]
153160
web = ["client/dist/**"]

sqlmesh/lsp/main.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#!/usr/bin/env python
2+
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration."""
3+
4+
import logging
5+
import typing as t
6+
import weakref
7+
from contextlib import suppress
8+
from pathlib import Path
9+
10+
from sqlmesh.core.context import Context
11+
from lsprotocol import types
12+
from pygls.server import LanguageServer
13+
from pygls.workspace import TextDocument
14+
from sqlmesh._version import __version__
15+
from sqlmesh.core.model import Model
16+
17+
logger = logging.getLogger(__name__)
18+
19+
CONTEXTS: t.Dict[str, Context] = {}
20+
"""A mapping of workspace paths to SQLMesh contexts."""
21+
22+
PATHS_TO_MODELS: t.Dict[str, t.Tuple[Context, Model]] = {}
23+
"""A mapping of file paths to SQLMesh (context, model) tuples."""
24+
25+
server = LanguageServer("sqlmesh_lsp", __version__)
26+
27+
_CACHE: t.Set[str] = set()
28+
"""A cache of URIs for which we have already ensured a context exists."""
29+
30+
31+
def ensure_context_for_document(document: TextDocument) -> TextDocument:
32+
"""Ensure that a context exists for the given document if applicable."""
33+
if document.uri in _CACHE:
34+
return document
35+
_CACHE.add(document.uri)
36+
path = Path(document.path).resolve()
37+
if path.suffix not in (".sql", ".py"):
38+
return document
39+
initial_path = path
40+
while path.parents:
41+
if str(path) in CONTEXTS:
42+
return document
43+
path = path.parent
44+
path = initial_path
45+
loaded = False
46+
while path.parents and not loaded:
47+
for ext in ("py", "yml", "yaml"):
48+
config_path = path / f"config.{ext}"
49+
if config_path.exists():
50+
with suppress(Exception):
51+
handle = Context(paths=[f"{path}"])
52+
CONTEXTS[str(path)] = handle
53+
PATHS_TO_MODELS.update(
54+
{
55+
str(model._path.resolve()): (handle, weakref.proxy(model))
56+
for model in handle.models.values()
57+
}
58+
)
59+
server.show_message(f"Context loaded for: {path}")
60+
loaded = True
61+
break
62+
path = path.parent
63+
return document
64+
65+
66+
@server.feature(types.TEXT_DOCUMENT_FORMATTING)
67+
def formatting(
68+
ls: LanguageServer, params: types.DocumentFormattingParams
69+
) -> t.List[types.TextEdit]:
70+
"""Format the document based using SQLMesh format_model_expressions."""
71+
try:
72+
logger.info(f"Formatting document: {params.text_document.uri}")
73+
document = ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
74+
context, _ = PATHS_TO_MODELS.get(document.path, (None, None))
75+
if context is None:
76+
logger.error(f"No context found for document: {document.path}")
77+
return []
78+
context.format(paths=(Path(document.path),))
79+
with open(document.path, "r+", encoding="utf-8") as file:
80+
return [
81+
types.TextEdit(
82+
range=types.Range(
83+
types.Position(0, 0),
84+
types.Position(len(document.lines), len(document.lines[-1])),
85+
),
86+
new_text=file.read(),
87+
)
88+
]
89+
except Exception as e:
90+
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
91+
return []
92+
93+
94+
def main() -> None:
95+
logging.basicConfig(level=logging.DEBUG)
96+
server.start_io()
97+
98+
99+
if __name__ == "__main__":
100+
main()

0 commit comments

Comments
 (0)