Skip to content

Commit 94a2960

Browse files
authored
feat(vscode): add custom api call method (#4352)
1 parent 107f714 commit 94a2960

2 files changed

Lines changed: 86 additions & 0 deletions

File tree

sqlmesh/lsp/api.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
This module maps the LSP custom API calls to the SQLMesh web api.
3+
4+
Allowing the LSP to call the web api without having to know the details of the web api
5+
and thus passing through the details of the web api to the LSP, so that both the LSP
6+
and the web api can communicate with the same process, avoiding the need to have a
7+
separate process for the web api.
8+
"""
9+
10+
import typing as t
11+
from pydantic import field_validator
12+
from sqlmesh.utils.pydantic import PydanticModel
13+
from web.server.models import Model
14+
15+
API_FEATURE = "sqlmesh/api"
16+
17+
18+
class ApiRequest(PydanticModel):
19+
"""
20+
Request to call the SQLMesh API.
21+
This is a generic request that can be used to call any API endpoint.
22+
"""
23+
24+
requestId: str
25+
url: str
26+
method: t.Optional[str] = "GET"
27+
params: t.Optional[t.Dict[str, t.Any]] = None
28+
body: t.Optional[t.Dict[str, t.Any]] = None
29+
30+
31+
class ApiResponseGetModels(PydanticModel):
32+
"""
33+
Response from the SQLMesh API for the get_models endpoint.
34+
"""
35+
36+
data: t.List[Model]
37+
38+
@field_validator("data", mode="before")
39+
def sanitize_datetime_fields(cls, data: t.List[Model]) -> t.List[Model]:
40+
"""
41+
Convert datetime objects to None to avoid serialization issues.
42+
"""
43+
if isinstance(data, list):
44+
for model in data:
45+
if hasattr(model, "details") and model.details:
46+
# Convert datetime fields to None to avoid serialization issues
47+
for field in ["stamp", "start", "cron_prev", "cron_next"]:
48+
if (
49+
hasattr(model.details, field)
50+
and getattr(model.details, field) is not None
51+
):
52+
setattr(model.details, field, None)
53+
return data
54+
55+
56+
class ApiResponseGetLineage(PydanticModel):
57+
"""
58+
Response from the SQLMesh API for the get_lineage endpoint.
59+
"""
60+
61+
data: t.Dict[str, t.List[str]]

sqlmesh/lsp/main.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,20 @@
1111
from sqlmesh._version import __version__
1212
from sqlmesh.core.context import Context
1313
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
14+
from sqlmesh.lsp.api import (
15+
API_FEATURE,
16+
ApiRequest,
17+
ApiResponseGetLineage,
18+
ApiResponseGetModels,
19+
)
1420
from sqlmesh.lsp.completions import get_sql_completions
1521
from sqlmesh.lsp.context import LSPContext, ModelTarget
1622
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
1723
from sqlmesh.lsp.reference import (
1824
get_references,
1925
)
26+
from web.server.api.endpoints.lineage import model_lineage
27+
from web.server.api.endpoints.models import get_models
2028

2129

2230
class SQLMeshLanguageServer:
@@ -79,6 +87,23 @@ def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsRespons
7987
except Exception as e:
8088
return get_sql_completions(None, params.textDocument.uri)
8189

90+
@self.server.feature(API_FEATURE)
91+
def api(
92+
ls: LanguageServer, request: ApiRequest
93+
) -> t.Union[ApiResponseGetModels, ApiResponseGetLineage]:
94+
ls.log_trace(f"API request: {request}")
95+
if self.lsp_context is None:
96+
raise RuntimeError("No context found")
97+
if request.url == "/api/models":
98+
response = ApiResponseGetModels(data=get_models(self.lsp_context.context))
99+
return response
100+
if request.url.startswith("/api/lineage"):
101+
name = request.url.split("/")[-1]
102+
lineage = model_lineage(name, self.lsp_context.context)
103+
non_set_lineage = {k: v for k, v in lineage.items() if v is not None}
104+
return ApiResponseGetLineage(data=non_set_lineage)
105+
raise NotImplementedError(f"API request not implemented: {request.url}")
106+
82107
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
83108
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
84109
context = self._context_get_or_load(params.text_document.uri)

0 commit comments

Comments
 (0)