Skip to content

Commit 9bc5a62

Browse files
author
Christopher Giroir
committed
chore: apply pr comments
1 parent f4f8006 commit 9bc5a62

2 files changed

Lines changed: 15 additions & 32 deletions

File tree

sqlmesh/lsp/hints.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_hints(
6161

6262

6363
def _get_type_hints_for_select(
64-
expression: exp.Select,
64+
expression: exp.Expression,
6565
dialect: str,
6666
columns_to_types: t.Dict[str, exp.DataType],
6767
start_line: int,
@@ -113,28 +113,6 @@ def _get_type_hints_for_select(
113113
return hints
114114

115115

116-
def _get_type_hints_for_expression(
117-
expression: Expression,
118-
dialect: str,
119-
columns_to_types: t.Dict[str, exp.DataType],
120-
start_line: int,
121-
end_line: int,
122-
) -> t.List[types.InlayHint]:
123-
if isinstance(expression, exp.Union):
124-
return _get_type_hints_for_expression(
125-
expression.this, dialect, columns_to_types, start_line, end_line
126-
) + _get_type_hints_for_expression(
127-
expression.expression, dialect, columns_to_types, start_line, end_line
128-
)
129-
130-
if isinstance(expression, exp.Select):
131-
return _get_type_hints_for_select(
132-
expression, dialect, columns_to_types, start_line, end_line
133-
)
134-
135-
return []
136-
137-
138116
def _get_type_hints_for_model_from_query(
139117
query: Expression,
140118
dialect: str,
@@ -150,8 +128,13 @@ def _get_type_hints_for_model_from_query(
150128
if not root:
151129
return []
152130

153-
return _get_type_hints_for_expression(
154-
root.expression, dialect, columns_to_types, start_line, end_line
155-
)
131+
return [
132+
hint
133+
for q in query.walk(prune=lambda n: not isinstance(n, exp.SetOperation))
134+
if isinstance(select := q.unnest(), exp.Select)
135+
for hint in _get_type_hints_for_select(
136+
q, dialect, columns_to_types, start_line, end_line
137+
)
138+
]
156139
except Exception:
157140
return []

tests/lsp/test_hints.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,9 @@ def test_union_hints() -> None:
7979
)
8080

8181
assert len(result) == 3
82-
assert result[0].label == "::TEXT"
83-
assert result[1].label == "::INT"
84-
assert result[2].label == "::DATE"
82+
assert result[0].label == "::DATE"
83+
assert result[1].label == "::TEXT"
84+
assert result[2].label == "::INT"
8585

8686

8787
@pytest.mark.fast
@@ -198,6 +198,6 @@ def test_cte_with_union_hints() -> None:
198198
)
199199

200200
assert len(result) == 3
201-
assert result[0].label == "::TEXT"
202-
assert result[1].label == "::DATE"
203-
assert result[2].label == "::INT"
201+
assert result[0].label == "::INT"
202+
assert result[1].label == "::TEXT"
203+
assert result[2].label == "::DATE"

0 commit comments

Comments
 (0)