Skip to content

Commit f4f8006

Browse files
author
Christopher Giroir
committed
fix: PR fixes and many more tests
1 parent 0a2e6f4 commit f4f8006

3 files changed

Lines changed: 143 additions & 49 deletions

File tree

sqlmesh/lsp/hints.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@ def _get_type_hints_for_select(
7171

7272
for select_exp in expression.expressions:
7373
if isinstance(select_exp, exp.Alias):
74+
if isinstance(select_exp.this, exp.Cast):
75+
continue
76+
7477
meta = select_exp.args["alias"]._meta
78+
7579
elif isinstance(select_exp, exp.Column):
7680
meta = select_exp.parts[-1]._meta
7781
else:

tests/lsp/test_hints.py

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,8 @@ def test_hints() -> None:
5757
start_line=0,
5858
end_line=9999,
5959
)
60-
assert len(crbd_hints) == 4
60+
assert len(crbd_hints) == 1
6161
assert crbd_hints[0].label == "::INT"
62-
assert crbd_hints[1].label == "::DOUBLE"
63-
assert crbd_hints[2].label == "::INT"
64-
assert crbd_hints[3].label == "::DATE"
6562

6663

6764
@pytest.mark.fast
@@ -105,3 +102,102 @@ def test_complex_hints() -> None:
105102
assert len(result) == 2
106103
assert result[0].label == "::VARCHAR(100)"
107104
assert result[1].label == "::STRUCT<INT, STRUCT<TEXT, INT[]>>"
105+
106+
107+
@pytest.mark.fast
108+
def test_simple_cast_hints() -> None:
109+
"""Don't add type hints if the expression is already a cast"""
110+
query = parse_one("SELECT a::INT, CAST(b AS DATE), c FROM d", dialect="postgres")
111+
112+
result = _get_type_hints_for_model_from_query(
113+
query=query,
114+
dialect="postgres",
115+
columns_to_types={
116+
"a": exp.DataType.build("INT"),
117+
"b": exp.DataType.build("DATE"),
118+
"c": exp.DataType.build("TEXT"),
119+
},
120+
start_line=0,
121+
end_line=1,
122+
)
123+
124+
assert len(result) == 1
125+
assert result[0].label == "::TEXT"
126+
127+
128+
@pytest.mark.fast
129+
def test_alias_cast_hints() -> None:
130+
"""Don't add type hints if the expression is already a cast"""
131+
query = parse_one(
132+
"SELECT raw_a::INT AS a, CAST(raw_b AS DATE) AS b, c FROM d", dialect="postgres"
133+
)
134+
135+
result = _get_type_hints_for_model_from_query(
136+
query=query,
137+
dialect="postgres",
138+
columns_to_types={
139+
"a": exp.DataType.build("INT"),
140+
"b": exp.DataType.build("DATE"),
141+
"c": exp.DataType.build("TEXT"),
142+
},
143+
start_line=0,
144+
end_line=1,
145+
)
146+
147+
assert len(result) == 1
148+
assert result[0].label == "::TEXT"
149+
150+
151+
@pytest.mark.fast
152+
def test_simple_cte_hints() -> None:
153+
"""Don't add type hints if the expression is already a cast"""
154+
query = parse_one("WITH t AS (SELECT a FROM b) SELECT a AS c FROM t", dialect="postgres")
155+
156+
result = _get_type_hints_for_model_from_query(
157+
query=query,
158+
dialect="postgres",
159+
columns_to_types={
160+
"c": exp.DataType.build("INT"),
161+
},
162+
start_line=0,
163+
end_line=1,
164+
)
165+
166+
assert len(result) == 1
167+
assert result[0].label == "::INT"
168+
169+
170+
@pytest.mark.fast
171+
def test_cte_with_union_hints() -> None:
172+
"""Don't add type hints if the expression is already a cast"""
173+
query = parse_one(
174+
"""WITH x AS (SELECT a FROM t),
175+
y AS (SELECT b FROM t),
176+
z AS (SELECT c FROM t)
177+
SELECT a AS d FROM x
178+
UNION
179+
SELECT b AS e FROM y
180+
UNION
181+
SELECT c AS f FROM z""",
182+
dialect="postgres",
183+
)
184+
185+
result = _get_type_hints_for_model_from_query(
186+
query=query,
187+
dialect="postgres",
188+
columns_to_types={
189+
"a": exp.DataType.build("TEXT"),
190+
"b": exp.DataType.build("DATE"),
191+
"c": exp.DataType.build("INT"),
192+
"d": exp.DataType.build("TEXT"),
193+
"e": exp.DataType.build("DATE"),
194+
"f": exp.DataType.build("INT"),
195+
},
196+
start_line=0,
197+
end_line=9999,
198+
)
199+
200+
assert len(result) == 3
201+
assert result[0].label == "::TEXT"
202+
assert result[1].label == "::DATE"
203+
assert result[2].label == "::INT"

vscode/extension/tests/hints.spec.ts

Lines changed: 39 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,49 +5,43 @@ import os from 'os'
55
import { startVSCode, SUSHI_SOURCE_PATH } from './utils'
66

77
test('Model type hinting', async () => {
8-
const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'vscode-test-sushi-'))
9-
await fs.copy(SUSHI_SOURCE_PATH, tempDir)
10-
11-
try {
12-
const { window, close } = await startVSCode(tempDir)
13-
14-
// Wait for the models folder to be visible
15-
await window.waitForSelector('text=models')
16-
17-
// Click on the models folder
18-
await window
19-
.getByRole('treeitem', { name: 'models', exact: true })
20-
.locator('a')
21-
.click()
22-
23-
// Open the customers_revenue_by_day model
24-
await window
25-
.getByRole('treeitem', { name: 'customer_revenue_by_day.sql', exact: true })
26-
.locator('a')
27-
.click()
28-
29-
await window.waitForSelector('text=grain')
30-
await window.waitForSelector('text=Loaded SQLMesh Context')
31-
32-
// Wait a moment for hints to appear
33-
await window.waitForTimeout(500)
34-
35-
// Check if the hints are visible
36-
expect(
37-
await window.locator('text=customer_id::INT').count(),
38-
).toBe(1)
39-
expect(
40-
await window.locator('text=revenue::DOUBLE').count(),
41-
).toBe(1)
42-
expect(
43-
await window.locator('text="country code"::INT').count(),
44-
).toBe(1)
45-
expect(
46-
await window.locator('text=event_date::DATE').count(),
47-
).toBe(1)
48-
49-
await close()
50-
} finally {
51-
await fs.remove(tempDir)
52-
}
8+
const tempDir = await fs.mkdtemp(
9+
path.join(os.tmpdir(), 'vscode-test-sushi-'),
10+
)
11+
await fs.copy(SUSHI_SOURCE_PATH, tempDir)
12+
13+
try {
14+
const { window, close } = await startVSCode(tempDir)
15+
16+
// Wait for the models folder to be visible
17+
await window.waitForSelector('text=models')
18+
19+
// Click on the models folder
20+
await window
21+
.getByRole('treeitem', { name: 'models', exact: true })
22+
.locator('a')
23+
.click()
24+
25+
// Open the customers_revenue_by_day model
26+
await window
27+
.getByRole('treeitem', {
28+
name: 'customer_revenue_by_day.sql',
29+
exact: true,
30+
})
31+
.locator('a')
32+
.click()
33+
34+
await window.waitForSelector('text=grain')
35+
await window.waitForSelector('text=Loaded SQLMesh Context')
36+
37+
// Wait a moment for hints to appear
38+
await window.waitForTimeout(500)
39+
40+
// Check if the hint is visible
41+
expect(await window.locator('text="country code"::INT').count()).toBe(1)
42+
43+
await close()
44+
} finally {
45+
await fs.remove(tempDir)
46+
}
5347
})

0 commit comments

Comments
 (0)