Skip to content

Commit b053969

Browse files
author
Christopher Giroir
committed
fix: handle union types
1 parent 66fcc32 commit b053969

2 files changed

Lines changed: 121 additions & 41 deletions

File tree

sqlmesh/lsp/hints.py

Lines changed: 75 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from sqlglot import exp
88
from sqlglot.expressions import Expression
99
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
10-
from sqlglot.optimizer.scope import build_scope, find_all_in_scope
10+
from sqlglot.optimizer.scope import build_scope
1111
from sqlmesh.core.model.definition import SqlModel
1212
from sqlmesh.lsp.context import LSPContext, ModelTarget
1313
from sqlmesh.lsp.uri import URI
@@ -60,6 +60,77 @@ def get_hints(
6060
)
6161

6262

63+
def _get_type_hints_for_select(
64+
expression: exp.Select,
65+
dialect: str,
66+
columns_to_types: t.Dict[str, exp.DataType],
67+
start_line: int,
68+
end_line: int,
69+
) -> t.List[types.InlayHint]:
70+
hints: t.List[types.InlayHint] = []
71+
72+
for select_exp in expression.expressions:
73+
if isinstance(select_exp, exp.Alias):
74+
meta = select_exp.args["alias"]._meta
75+
elif isinstance(select_exp, exp.Column):
76+
meta = select_exp.parts[-1]._meta
77+
else:
78+
continue
79+
80+
if "line" not in meta or "col" not in meta:
81+
continue
82+
83+
line = meta["line"]
84+
col = meta["col"]
85+
86+
# Lines from sqlglot are 1 based
87+
line -= 1
88+
89+
if line < start_line or line > end_line:
90+
continue
91+
92+
name = select_exp.alias_or_name
93+
data_type = columns_to_types.get(name)
94+
95+
if not data_type or data_type.is_type(exp.DataType.Type.UNKNOWN):
96+
continue
97+
98+
type_label = data_type.sql(dialect)
99+
hints.append(
100+
types.InlayHint(
101+
label=f"::{type_label}",
102+
kind=types.InlayHintKind.Type,
103+
padding_left=False,
104+
padding_right=True,
105+
position=types.Position(line=line, character=col),
106+
)
107+
)
108+
109+
return hints
110+
111+
112+
def _get_type_hints_for_expression(
113+
expression: Expression,
114+
dialect: str,
115+
columns_to_types: t.Dict[str, exp.DataType],
116+
start_line: int,
117+
end_line: int,
118+
) -> t.List[types.InlayHint]:
119+
if isinstance(expression, exp.Union):
120+
return _get_type_hints_for_expression(
121+
expression.this, dialect, columns_to_types, start_line, end_line
122+
) + _get_type_hints_for_expression(
123+
expression.expression, dialect, columns_to_types, start_line, end_line
124+
)
125+
126+
if isinstance(expression, exp.Select):
127+
return _get_type_hints_for_select(
128+
expression, dialect, columns_to_types, start_line, end_line
129+
)
130+
131+
return []
132+
133+
63134
def _get_type_hints_for_model_from_query(
64135
query: Expression,
65136
dialect: str,
@@ -75,44 +146,8 @@ def _get_type_hints_for_model_from_query(
75146
if not root:
76147
return []
77148

78-
for select in find_all_in_scope(root.expression, exp.Select):
79-
for select_exp in select.expressions:
80-
if isinstance(select_exp, exp.Alias):
81-
meta = select_exp.args["alias"]._meta
82-
elif isinstance(select_exp, exp.Column):
83-
meta = select_exp.parts[-1]._meta
84-
else:
85-
continue
86-
87-
if "line" not in meta or "col" not in meta:
88-
continue
89-
90-
line = meta["line"]
91-
col = meta["col"]
92-
93-
# Lines from sqlglot are 1 based
94-
line -= 1
95-
96-
if line < start_line or line > end_line:
97-
continue
98-
99-
name = select_exp.alias_or_name
100-
data_type = columns_to_types.get(name)
101-
102-
if not data_type or data_type.is_type(exp.DataType.Type.UNKNOWN):
103-
continue
104-
105-
type_label = data_type.sql(dialect)
106-
hints.append(
107-
types.InlayHint(
108-
label=f"::{type_label}",
109-
kind=types.InlayHintKind.Type,
110-
padding_left=False,
111-
padding_right=True,
112-
position=types.Position(line=line, character=col),
113-
)
114-
)
115-
116-
return hints
149+
return _get_type_hints_for_expression(
150+
root.expression, dialect, columns_to_types, start_line, end_line
151+
)
117152
except Exception:
118153
return []

tests/lsp/test_hints.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import pytest
44

5+
from sqlglot import exp, parse_one
6+
57
from sqlmesh.core.context import Context
68
from sqlmesh.lsp.context import LSPContext, ModelTarget
7-
from sqlmesh.lsp.hints import get_hints
9+
from sqlmesh.lsp.hints import get_hints, _get_type_hints_for_model_from_query
810
from sqlmesh.lsp.uri import URI
911

1012

@@ -60,3 +62,46 @@ def test_hints() -> None:
6062
assert crbd_hints[1].label == "::DOUBLE"
6163
assert crbd_hints[2].label == "::INT"
6264
assert crbd_hints[3].label == "::DATE"
65+
66+
67+
@pytest.mark.fast
68+
def test_union_hints() -> None:
69+
query_str = """SELECT a FROM table_a UNION SELECT b FROM table_b UNION SELECT c FROM table_c"""
70+
query = parse_one(query_str, dialect="postgres")
71+
72+
result = _get_type_hints_for_model_from_query(
73+
query=query,
74+
dialect="postgres",
75+
columns_to_types={
76+
"a": exp.DataType.build("TEXT"),
77+
"b": exp.DataType.build("INT"),
78+
"c": exp.DataType.build("DATE"),
79+
},
80+
start_line=0,
81+
end_line=1,
82+
)
83+
84+
assert len(result) == 3
85+
assert result[0].label == "::TEXT"
86+
assert result[1].label == "::INT"
87+
assert result[2].label == "::DATE"
88+
89+
90+
@pytest.mark.fast
91+
def test_complex_hints() -> None:
92+
query = parse_one("SELECT a, b FROM c", dialect="postgres")
93+
94+
result = _get_type_hints_for_model_from_query(
95+
query=query,
96+
dialect="postgres",
97+
columns_to_types={
98+
"a": exp.DataType.build("VARCHAR(100)"),
99+
"b": exp.DataType.build("STRUCT<INT, STRUCT<TEXT, ARRAY<INT>>>"),
100+
},
101+
start_line=0,
102+
end_line=1,
103+
)
104+
105+
assert len(result) == 2
106+
assert result[0].label == "::VARCHAR(100)"
107+
assert result[1].label == "::STRUCT<INT, STRUCT<TEXT, INT[]>>"

0 commit comments

Comments
 (0)