Skip to content

Commit 0a584c8

Browse files
author
Christopher Giroir
committed
feat: add inlay type hints for sqlmesh sql models
1 parent a29ad26 commit 0a584c8

3 files changed

Lines changed: 198 additions & 3 deletions

File tree

sqlmesh/lsp/hints.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""Type hinting on SQLMesh models"""
2+
3+
import typing as t
4+
5+
from lsprotocol import types
6+
7+
from sqlglot import exp
8+
from sqlglot.expressions import Expression
9+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
10+
from sqlglot.optimizer.scope import build_scope, find_all_in_scope
11+
from sqlmesh.core.model.definition import SqlModel
12+
from sqlmesh.lsp.context import LSPContext, ModelTarget
13+
from sqlmesh.lsp.uri import URI
14+
15+
16+
def get_hints(
17+
lsp_context: LSPContext,
18+
document_uri: URI,
19+
start_line: int,
20+
end_line: int,
21+
) -> t.List[types.InlayHint]:
22+
"""
23+
Get type hints for certain lines in a document
24+
25+
Args:
26+
lint_context: The LSP context
27+
document_uri: The URI of the document
28+
start_line: the starting line to get hints for
29+
end_line: the ending line to get hints for
30+
31+
Returns:
32+
A list of hints to apply to the document
33+
"""
34+
path = document_uri.to_path()
35+
if path.suffix != ".sql":
36+
return []
37+
38+
if path not in lsp_context.map:
39+
return []
40+
41+
file_info = lsp_context.map[path]
42+
43+
# Process based on whether it's a model or standalone audit
44+
if isinstance(file_info, ModelTarget):
45+
# It's a model
46+
model = lsp_context.context.get_model(
47+
model_or_snapshot=file_info.names[0], raise_if_missing=False
48+
)
49+
if model is None or not isinstance(model, SqlModel):
50+
return []
51+
52+
query = model.query
53+
dialect = model.dialect
54+
columns_to_types = model.columns_to_types or {}
55+
else:
56+
return []
57+
58+
return _get_type_hints_for_model_from_query(
59+
query, dialect, columns_to_types, start_line, end_line
60+
)
61+
62+
63+
def _get_type_hints_for_model_from_query(
64+
query: Expression,
65+
dialect: str,
66+
columns_to_types: t.Dict[str, t.Any],
67+
start_line: int,
68+
end_line: int,
69+
) -> t.List[types.InlayHint]:
70+
hints: t.List[types.InlayHint] = []
71+
try:
72+
query = normalize_identifiers(query.copy(), dialect=dialect)
73+
root = build_scope(query)
74+
75+
if not root:
76+
return []
77+
78+
for select in find_all_in_scope(root.expression, exp.Select):
79+
for select_exp in select.expressions:
80+
if not select_exp:
81+
continue
82+
83+
if isinstance(select_exp, exp.Alias):
84+
meta = select_exp.args["alias"]._meta
85+
elif isinstance(select_exp, exp.Column):
86+
meta = select_exp.parts[-1]._meta
87+
else:
88+
continue
89+
90+
line = meta.get("line") - 1 # Lines from sqlglot are 1 based
91+
col = meta.get("col")
92+
93+
name = select_exp.alias_or_name
94+
if name not in columns_to_types:
95+
continue
96+
97+
if line < start_line or line > end_line:
98+
continue
99+
100+
type_label = str(columns_to_types.get(name))
101+
hints.append(
102+
types.InlayHint(
103+
label=f"::{type_label}",
104+
kind=types.InlayHintKind.Type,
105+
padding_left=False,
106+
padding_right=True,
107+
position=types.Position(line=line, character=col),
108+
)
109+
)
110+
111+
return hints
112+
except Exception:
113+
return []

sqlmesh/lsp/main.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,8 @@
3636
RenderModelRequest,
3737
RenderModelResponse,
3838
)
39-
from sqlmesh.lsp.reference import (
40-
get_references,
41-
)
39+
from sqlmesh.lsp.hints import get_hints
40+
from sqlmesh.lsp.reference import get_references
4241
from sqlmesh.lsp.uri import URI
4342
from web.server.api.endpoints.lineage import column_lineage, model_lineage
4443
from web.server.api.endpoints.models import get_models
@@ -324,6 +323,27 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
324323
ls.show_message(f"Error getting hover information: {e}", types.MessageType.Error)
325324
return None
326325

326+
@self.server.feature(types.TEXT_DOCUMENT_INLAY_HINT)
327+
def inlay_hint(
328+
ls: LanguageServer, params: types.InlayHintParams
329+
) -> t.List[types.InlayHint]:
330+
"""Implement type hints for sql columns as inlay hints"""
331+
try:
332+
uri = URI(params.text_document.uri)
333+
self._ensure_context_for_document(uri)
334+
document = ls.workspace.get_text_document(params.text_document.uri)
335+
if self.lsp_context is None:
336+
raise RuntimeError(f"No context found for document: {document.path}")
337+
338+
start_line = params.range.start.line
339+
end_line = params.range.end.line
340+
hints = get_hints(self.lsp_context, uri, start_line, end_line)
341+
return hints
342+
343+
except Exception as e:
344+
ls.show_message(f"Error getting type hints: {e}", types.MessageType.Error)
345+
return []
346+
327347
@self.server.feature(types.TEXT_DOCUMENT_DEFINITION)
328348
def goto_definition(
329349
ls: LanguageServer, params: types.DefinitionParams

tests/lsp/test_hints.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Tests for type hinting SQLMesh models"""
2+
3+
import pytest
4+
5+
from sqlmesh.core.context import Context
6+
from sqlmesh.lsp.context import LSPContext, ModelTarget
7+
from sqlmesh.lsp.hints import get_hints
8+
from sqlmesh.lsp.uri import URI
9+
10+
11+
@pytest.mark.fast
12+
def test_hints() -> None:
13+
context = Context(paths=["examples/sushi"])
14+
lsp_context = LSPContext(context)
15+
16+
# Find model URIs
17+
active_customers_path = next(
18+
path
19+
for path, info in lsp_context.map.items()
20+
if isinstance(info, ModelTarget) and "sushi.active_customers" in info.names
21+
)
22+
customer_revenue_lifetime_path = next(
23+
path
24+
for path, info in lsp_context.map.items()
25+
if isinstance(info, ModelTarget) and "sushi.customer_revenue_lifetime" in info.names
26+
)
27+
customer_revenue_by_day_path = next(
28+
path
29+
for path, info in lsp_context.map.items()
30+
if isinstance(info, ModelTarget) and "sushi.customer_revenue_by_day" in info.names
31+
)
32+
33+
active_customers_uri = URI.from_path(active_customers_path)
34+
ac_hints = get_hints(lsp_context, active_customers_uri, start_line=0, end_line=9999)
35+
assert len(ac_hints) == 2
36+
assert ac_hints[0].label == "::INT"
37+
assert ac_hints[1].label == "::TEXT"
38+
39+
customer_revenue_lifetime_uri = URI.from_path(customer_revenue_lifetime_path)
40+
crl_hints = get_hints(
41+
lsp_context=lsp_context,
42+
document_uri=customer_revenue_lifetime_uri,
43+
start_line=0,
44+
end_line=9999,
45+
)
46+
assert len(crl_hints) == 3
47+
assert crl_hints[0].label == "::INT"
48+
assert crl_hints[1].label == "::DOUBLE"
49+
assert crl_hints[2].label == "::DATE"
50+
51+
customer_revenue_by_day_uri = URI.from_path(customer_revenue_by_day_path)
52+
crbd_hints = get_hints(
53+
lsp_context=lsp_context,
54+
document_uri=customer_revenue_by_day_uri,
55+
start_line=0,
56+
end_line=9999,
57+
)
58+
assert len(crbd_hints) == 4
59+
assert crbd_hints[0].label == "::INT"
60+
assert crbd_hints[1].label == "::DOUBLE"
61+
assert crbd_hints[2].label == "::INT"
62+
assert crbd_hints[3].label == "::DATE"

0 commit comments

Comments
 (0)