Skip to content

Commit ff87b03

Browse files
author
Christopher Giroir
authored
feat: add inlay type hints for sqlmesh sql models (#4641)
1 parent 2595aca commit ff87b03

4 files changed

Lines changed: 406 additions & 0 deletions

File tree

sqlmesh/lsp/hints.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""Type hinting on SQLMesh models"""
2+
3+
import typing as t
4+
5+
from lsprotocol import types
6+
7+
from sqlglot import exp
8+
from sqlglot.expressions import Expression
9+
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
10+
from sqlmesh.core.model.definition import SqlModel
11+
from sqlmesh.lsp.context import LSPContext, ModelTarget
12+
from sqlmesh.lsp.uri import URI
13+
14+
15+
def get_hints(
16+
lsp_context: LSPContext,
17+
document_uri: URI,
18+
start_line: int,
19+
end_line: int,
20+
) -> t.List[types.InlayHint]:
21+
"""
22+
Get type hints for certain lines in a document
23+
24+
Args:
25+
lint_context: The LSP context
26+
document_uri: The URI of the document
27+
start_line: the starting line to get hints for
28+
end_line: the ending line to get hints for
29+
30+
Returns:
31+
A list of hints to apply to the document
32+
"""
33+
path = document_uri.to_path()
34+
if path.suffix != ".sql":
35+
return []
36+
37+
if path not in lsp_context.map:
38+
return []
39+
40+
file_info = lsp_context.map[path]
41+
42+
# Process based on whether it's a model or standalone audit
43+
if not isinstance(file_info, ModelTarget):
44+
return []
45+
46+
# It's a model
47+
model = lsp_context.context.get_model(
48+
model_or_snapshot=file_info.names[0], raise_if_missing=False
49+
)
50+
if not isinstance(model, SqlModel):
51+
return []
52+
53+
query = model.query
54+
dialect = model.dialect
55+
columns_to_types = model.columns_to_types or {}
56+
57+
return _get_type_hints_for_model_from_query(
58+
query, dialect, columns_to_types, start_line, end_line
59+
)
60+
61+
62+
def _get_type_hints_for_select(
63+
expression: exp.Expression,
64+
dialect: str,
65+
columns_to_types: t.Dict[str, exp.DataType],
66+
start_line: int,
67+
end_line: int,
68+
) -> t.List[types.InlayHint]:
69+
hints: t.List[types.InlayHint] = []
70+
71+
for select_exp in expression.expressions:
72+
if isinstance(select_exp, exp.Alias):
73+
if isinstance(select_exp.this, exp.Cast):
74+
continue
75+
76+
meta = select_exp.args["alias"].meta
77+
78+
elif isinstance(select_exp, exp.Column):
79+
meta = select_exp.parts[-1].meta
80+
else:
81+
continue
82+
83+
if "line" not in meta or "col" not in meta:
84+
continue
85+
86+
line = meta["line"]
87+
col = meta["col"]
88+
89+
# Lines from sqlglot are 1 based
90+
line -= 1
91+
92+
if line < start_line or line > end_line:
93+
continue
94+
95+
name = select_exp.alias_or_name
96+
data_type = columns_to_types.get(name)
97+
98+
if not data_type or data_type.is_type(exp.DataType.Type.UNKNOWN):
99+
continue
100+
101+
type_label = data_type.sql(dialect)
102+
hints.append(
103+
types.InlayHint(
104+
label=f"::{type_label}",
105+
kind=types.InlayHintKind.Type,
106+
padding_left=False,
107+
padding_right=True,
108+
position=types.Position(line=line, character=col),
109+
)
110+
)
111+
112+
return hints
113+
114+
115+
def _get_type_hints_for_model_from_query(
116+
query: Expression,
117+
dialect: str,
118+
columns_to_types: t.Dict[str, exp.DataType],
119+
start_line: int,
120+
end_line: int,
121+
) -> t.List[types.InlayHint]:
122+
hints: t.List[types.InlayHint] = []
123+
try:
124+
query = normalize_identifiers(query.copy(), dialect=dialect)
125+
126+
# Return the hints for top level selects (model definition columns only)
127+
return [
128+
hint
129+
for q in query.walk(prune=lambda n: not isinstance(n, exp.SetOperation))
130+
if isinstance(select := q.unnest(), exp.Select)
131+
for hint in _get_type_hints_for_select(
132+
q, dialect, columns_to_types, start_line, end_line
133+
)
134+
]
135+
except Exception:
136+
return []

sqlmesh/lsp/main.py

100644100755
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
RenderModelRequest,
3737
RenderModelResponse,
3838
)
39+
from sqlmesh.lsp.hints import get_hints
3940
from sqlmesh.lsp.reference import get_references, get_cte_references
4041
from sqlmesh.lsp.uri import URI
4142
from web.server.api.endpoints.lineage import column_lineage, model_lineage
@@ -334,6 +335,25 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
334335
)
335336
return None
336337

338+
@self.server.feature(types.TEXT_DOCUMENT_INLAY_HINT)
339+
def inlay_hint(
340+
ls: LanguageServer, params: types.InlayHintParams
341+
) -> t.List[types.InlayHint]:
342+
"""Implement type hints for sql columns as inlay hints"""
343+
try:
344+
uri = URI(params.text_document.uri)
345+
self._ensure_context_for_document(uri)
346+
if self.lsp_context is None:
347+
raise RuntimeError(f"No context found for document: {uri}")
348+
349+
start_line = params.range.start.line
350+
end_line = params.range.end.line
351+
hints = get_hints(self.lsp_context, uri, start_line, end_line)
352+
return hints
353+
354+
except Exception as e:
355+
return []
356+
337357
@self.server.feature(types.TEXT_DOCUMENT_DEFINITION)
338358
def goto_definition(
339359
ls: LanguageServer, params: types.DefinitionParams

tests/lsp/test_hints.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
"""Tests for type hinting SQLMesh models"""
2+
3+
import pytest
4+
5+
from sqlglot import exp, parse_one
6+
7+
from sqlmesh.core.context import Context
8+
from sqlmesh.lsp.context import LSPContext, ModelTarget
9+
from sqlmesh.lsp.hints import get_hints, _get_type_hints_for_model_from_query
10+
from sqlmesh.lsp.uri import URI
11+
12+
13+
@pytest.mark.fast
14+
def test_hints() -> None:
15+
context = Context(paths=["examples/sushi"])
16+
lsp_context = LSPContext(context)
17+
18+
# Find model URIs
19+
active_customers_path = next(
20+
path
21+
for path, info in lsp_context.map.items()
22+
if isinstance(info, ModelTarget) and "sushi.active_customers" in info.names
23+
)
24+
customer_revenue_lifetime_path = next(
25+
path
26+
for path, info in lsp_context.map.items()
27+
if isinstance(info, ModelTarget) and "sushi.customer_revenue_lifetime" in info.names
28+
)
29+
customer_revenue_by_day_path = next(
30+
path
31+
for path, info in lsp_context.map.items()
32+
if isinstance(info, ModelTarget) and "sushi.customer_revenue_by_day" in info.names
33+
)
34+
35+
active_customers_uri = URI.from_path(active_customers_path)
36+
ac_hints = get_hints(lsp_context, active_customers_uri, start_line=0, end_line=9999)
37+
assert len(ac_hints) == 2
38+
assert ac_hints[0].label == "::INT"
39+
assert ac_hints[1].label == "::TEXT"
40+
41+
customer_revenue_lifetime_uri = URI.from_path(customer_revenue_lifetime_path)
42+
crl_hints = get_hints(
43+
lsp_context=lsp_context,
44+
document_uri=customer_revenue_lifetime_uri,
45+
start_line=0,
46+
end_line=9999,
47+
)
48+
assert len(crl_hints) == 3
49+
assert crl_hints[0].label == "::INT"
50+
assert crl_hints[1].label == "::DOUBLE"
51+
assert crl_hints[2].label == "::DATE"
52+
53+
customer_revenue_by_day_uri = URI.from_path(customer_revenue_by_day_path)
54+
crbd_hints = get_hints(
55+
lsp_context=lsp_context,
56+
document_uri=customer_revenue_by_day_uri,
57+
start_line=0,
58+
end_line=9999,
59+
)
60+
assert len(crbd_hints) == 1
61+
assert crbd_hints[0].label == "::INT"
62+
63+
64+
@pytest.mark.fast
65+
def test_union_hints() -> None:
66+
query_str = """SELECT a FROM table_a UNION SELECT b FROM table_b UNION SELECT c FROM table_c"""
67+
query = parse_one(query_str, dialect="postgres")
68+
69+
result = _get_type_hints_for_model_from_query(
70+
query=query,
71+
dialect="postgres",
72+
columns_to_types={
73+
"a": exp.DataType.build("TEXT"),
74+
"b": exp.DataType.build("INT"),
75+
"c": exp.DataType.build("DATE"),
76+
},
77+
start_line=0,
78+
end_line=1,
79+
)
80+
81+
assert len(result) == 3
82+
assert result[0].label == "::DATE"
83+
assert result[1].label == "::TEXT"
84+
assert result[2].label == "::INT"
85+
86+
87+
@pytest.mark.fast
88+
def test_complex_hints() -> None:
89+
query = parse_one("SELECT a, b FROM c", dialect="postgres")
90+
91+
result = _get_type_hints_for_model_from_query(
92+
query=query,
93+
dialect="postgres",
94+
columns_to_types={
95+
"a": exp.DataType.build("VARCHAR(100)"),
96+
"b": exp.DataType.build("STRUCT<INT, STRUCT<TEXT, ARRAY<INT>>>"),
97+
},
98+
start_line=0,
99+
end_line=1,
100+
)
101+
102+
assert len(result) == 2
103+
assert result[0].label == "::VARCHAR(100)"
104+
assert result[1].label == "::STRUCT<INT, STRUCT<TEXT, INT[]>>"
105+
106+
107+
@pytest.mark.fast
108+
def test_simple_cast_hints() -> None:
109+
"""Don't add type hints if the expression is already a cast"""
110+
query = parse_one("SELECT a::INT, CAST(b AS DATE), c FROM d", dialect="postgres")
111+
112+
result = _get_type_hints_for_model_from_query(
113+
query=query,
114+
dialect="postgres",
115+
columns_to_types={
116+
"a": exp.DataType.build("INT"),
117+
"b": exp.DataType.build("DATE"),
118+
"c": exp.DataType.build("TEXT"),
119+
},
120+
start_line=0,
121+
end_line=1,
122+
)
123+
124+
assert len(result) == 1
125+
assert result[0].label == "::TEXT"
126+
127+
128+
@pytest.mark.fast
129+
def test_alias_cast_hints() -> None:
130+
"""Don't add type hints if the expression is already a cast"""
131+
query = parse_one(
132+
"SELECT raw_a::INT AS a, CAST(raw_b AS DATE) AS b, c FROM d", dialect="postgres"
133+
)
134+
135+
result = _get_type_hints_for_model_from_query(
136+
query=query,
137+
dialect="postgres",
138+
columns_to_types={
139+
"a": exp.DataType.build("INT"),
140+
"b": exp.DataType.build("DATE"),
141+
"c": exp.DataType.build("TEXT"),
142+
},
143+
start_line=0,
144+
end_line=1,
145+
)
146+
147+
assert len(result) == 1
148+
assert result[0].label == "::TEXT"
149+
150+
151+
@pytest.mark.fast
152+
def test_simple_cte_hints() -> None:
153+
"""Don't add type hints if the expression is already a cast"""
154+
query = parse_one("WITH t AS (SELECT a FROM b) SELECT a AS c FROM t", dialect="postgres")
155+
156+
result = _get_type_hints_for_model_from_query(
157+
query=query,
158+
dialect="postgres",
159+
columns_to_types={
160+
"c": exp.DataType.build("INT"),
161+
},
162+
start_line=0,
163+
end_line=1,
164+
)
165+
166+
assert len(result) == 1
167+
assert result[0].label == "::INT"
168+
169+
170+
@pytest.mark.fast
171+
def test_cte_with_union_hints() -> None:
172+
"""Don't add type hints if the expression is already a cast"""
173+
query = parse_one(
174+
"""WITH x AS (SELECT a FROM t),
175+
y AS (SELECT b FROM t),
176+
z AS (SELECT c FROM t)
177+
SELECT a AS d FROM x
178+
UNION
179+
SELECT b AS e FROM y
180+
UNION
181+
SELECT c AS f FROM z""",
182+
dialect="postgres",
183+
)
184+
185+
result = _get_type_hints_for_model_from_query(
186+
query=query,
187+
dialect="postgres",
188+
columns_to_types={
189+
"a": exp.DataType.build("TEXT"),
190+
"b": exp.DataType.build("DATE"),
191+
"c": exp.DataType.build("INT"),
192+
"d": exp.DataType.build("TEXT"),
193+
"e": exp.DataType.build("DATE"),
194+
"f": exp.DataType.build("INT"),
195+
},
196+
start_line=0,
197+
end_line=9999,
198+
)
199+
200+
assert len(result) == 3
201+
assert result[0].label == "::INT"
202+
assert result[1].label == "::TEXT"
203+
assert result[2].label == "::DATE"

0 commit comments

Comments
 (0)