Skip to content

Commit fcc56ae

Browse files
committed
feat(lsp): add go to definition for ctes
[ci skip]
1 parent 50802b3 commit fcc56ae

4 files changed

Lines changed: 194 additions & 16 deletions

File tree

sqlmesh/lsp/main.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,21 +265,31 @@ def goto_definition(
265265
raise RuntimeError(f"No context found for document: {document.path}")
266266

267267
references = get_references(self.lsp_context, uri, params.position)
268-
return [
269-
types.LocationLink(
270-
target_uri=reference.uri,
271-
target_selection_range=types.Range(
268+
location_links = []
269+
for reference in references:
270+
# Use target_range if available (for CTEs), otherwise default to start of file
271+
if reference.target_range:
272+
target_range = reference.target_range
273+
target_selection_range = reference.target_range
274+
else:
275+
target_range = types.Range(
272276
start=types.Position(line=0, character=0),
273277
end=types.Position(line=0, character=0),
274-
),
275-
target_range=types.Range(
278+
)
279+
target_selection_range = types.Range(
276280
start=types.Position(line=0, character=0),
277281
end=types.Position(line=0, character=0),
278-
),
279-
origin_selection_range=reference.range,
282+
)
283+
284+
location_links.append(
285+
types.LocationLink(
286+
target_uri=reference.uri,
287+
target_selection_range=target_selection_range,
288+
target_range=target_range,
289+
origin_selection_range=reference.range,
290+
)
280291
)
281-
for reference in references
282-
]
292+
return location_links
283293
except Exception as e:
284294
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
285295
return []

sqlmesh/lsp/reference.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@
1111

1212
class Reference(PydanticModel):
1313
"""
14-
A reference to a model.
14+
A reference to a model or CTE.
1515
1616
Attributes:
1717
range: The range of the reference in the source file
18-
uri: The uri of the referenced model
19-
description: The description of the referenced model
18+
uri: The uri of the referenced model or file
19+
description: The description of the referenced model or CTE
20+
target_range: The range of the definition for go-to-definition (optional, used for CTEs)
2021
"""
2122

2223
range: Range
2324
uri: str
2425
description: t.Optional[str] = None
26+
target_range: t.Optional[Range] = None
2527

2628

2729
def by_position(position: Position) -> t.Callable[[Reference], bool]:
@@ -87,6 +89,7 @@ def get_model_definitions_for_a_path(
8789
- Need to normalize it before matching
8890
- Try get_model before normalization
8991
- Match to models that the model refers to
92+
- Also find CTE references within the query
9093
"""
9194
path = document_uri.to_path()
9295
if path.suffix != ".sql":
@@ -127,13 +130,50 @@ def get_model_definitions_for_a_path(
127130

128131
# Get SQL query and find all table references
129132
tables = list(query.find_all(exp.Table))
130-
if len(tables) == 0:
131-
return []
132-
133+
133134
with open(file_path, "r", encoding="utf-8") as file:
134135
read_file = file.readlines()
135136

137+
# Build a map of CTE names to their definitions for CTE go-to-definition
138+
cte_definitions = {}
139+
with_clause = query.find(exp.With)
140+
if with_clause:
141+
for cte in with_clause.expressions:
142+
if cte.alias and hasattr(cte, 'meta') and cte.meta:
143+
cte_definitions[cte.alias] = cte
144+
136145
for table in tables:
146+
table_name = table.name
147+
148+
# Check if this table reference is a CTE
149+
if table_name in cte_definitions:
150+
# This is a CTE reference - create a reference to the CTE definition
151+
cte_def = cte_definitions[table_name]
152+
if hasattr(cte_def, 'meta') and cte_def.meta:
153+
try:
154+
# Get the position of the table reference
155+
table_meta = TokenPositionDetails.from_meta(table.this.meta)
156+
table_range = _range_from_token_position_details(table_meta, read_file)
157+
158+
# Get the position of the CTE definition (alias part)
159+
cte_meta = TokenPositionDetails.from_meta(cte_def.meta)
160+
cte_range = _range_from_token_position_details(cte_meta, read_file)
161+
162+
# Create a reference from the table usage to the CTE definition
163+
references.append(
164+
Reference(
165+
uri=document_uri.value, # Same file
166+
range=table_range,
167+
description=f"CTE: {table_name}",
168+
target_range=cte_range
169+
)
170+
)
171+
except Exception:
172+
# Skip if we can't extract positioning info
173+
pass
174+
continue
175+
176+
# For non-CTE tables, process as before (external model references)
137177
# Normalize the table reference
138178
unaliased = table.copy()
139179
if unaliased.args.get("alias") is not None:

tests/lsp/simple_cte_test.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/usr/bin/env python3
2+
"""Simple test for CTE parsing with SQLGlot."""
3+
4+
import sqlglot
5+
from sqlglot import exp
6+
7+
def test_basic_cte():
8+
"""Test basic CTE functionality."""
9+
query_text = """WITH cte1 AS (SELECT 1 as id),
10+
cte2 AS (SELECT * FROM cte1)
11+
SELECT * FROM cte2"""
12+
13+
print("=== Basic CTE Test ===")
14+
print(f"Query: {query_text}")
15+
print()
16+
17+
# Parse the query
18+
parsed = sqlglot.parse_one(query_text)
19+
print(f"Parsed: {parsed}")
20+
print()
21+
22+
# Find WITH clause
23+
with_clause = parsed.find(exp.With)
24+
if with_clause:
25+
print("CTEs found:")
26+
cte_names = {}
27+
for cte in with_clause.expressions:
28+
print(f" - CTE alias: {cte.alias}")
29+
print(f" - CTE expression: {cte.this}")
30+
cte_names[cte.alias] = cte
31+
print()
32+
33+
# Find table references
34+
tables = list(parsed.find_all(exp.Table))
35+
print("Table references found:")
36+
for table in tables:
37+
table_name = table.name
38+
print(f" - Table name: {table_name}")
39+
if table_name in cte_names:
40+
print(f" -> This is a CTE reference to: {cte_names[table_name].alias}")
41+
else:
42+
print(f" -> This is likely an external table reference")
43+
print()
44+
45+
if __name__ == "__main__":
46+
test_basic_cte()

tests/lsp/test_cte_references.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#!/usr/bin/env python3
2+
"""Test script for CTE go-to-definition functionality."""
3+
4+
from sqlglot import exp, parse_one
5+
from sqlmesh.lsp.reference import TokenPositionDetails, _range_from_token_position_details
6+
7+
def test_cte_parsing():
8+
"""Test CTE parsing with SQLGlot."""
9+
query_text = """
10+
WITH cte1 AS (SELECT 1 as id),
11+
cte2 AS (SELECT * FROM cte1)
12+
SELECT * FROM cte2
13+
"""
14+
15+
# Parse the query with positional metadata
16+
query = parse_one(query_text, read='')
17+
18+
print("=== Testing CTE Parsing ===")
19+
print(f"Query: {query_text.strip()}")
20+
print()
21+
22+
# Find WITH clause
23+
with_clause = query.find(exp.With)
24+
if with_clause:
25+
print("CTEs found:")
26+
for cte in with_clause.expressions:
27+
print(f" - CTE alias: {cte.alias}")
28+
print(f" - CTE meta: {cte.meta}")
29+
print(f" - CTE this: {cte.this}")
30+
print()
31+
32+
# Find table references
33+
tables = list(query.find_all(exp.Table))
34+
print("Table references found:")
35+
for table in tables:
36+
print(f" - Table name: {table.name}")
37+
print(f" - Table meta: {table.meta}")
38+
print(f" - Table this meta: {table.this.meta}")
39+
print()
40+
41+
def test_with_metadata():
42+
"""Test parsing with metadata to get position information."""
43+
query_text = """WITH cte1 AS (SELECT 1 as id),
44+
cte2 AS (SELECT * FROM cte1)
45+
SELECT * FROM cte2"""
46+
47+
print("=== Testing with Position Metadata ===")
48+
print(f"Query: {query_text}")
49+
print()
50+
51+
# Parse with metadata
52+
from sqlglot import transpile
53+
from sqlglot.parser import Parser
54+
55+
# Try to get metadata
56+
try:
57+
parsed = parse_one(query_text, read='')
58+
print("Parsed successfully with metadata")
59+
60+
# Check WITH clause
61+
with_clause = parsed.find(exp.With)
62+
if with_clause:
63+
for cte in with_clause.expressions:
64+
print(f"CTE {cte.alias}:")
65+
print(f" Meta: {cte.meta}")
66+
if hasattr(cte, 'args') and 'this' in cte.args:
67+
print(f" This meta: {cte.args['this'].meta}")
68+
print()
69+
70+
# Check tables
71+
for table in parsed.find_all(exp.Table):
72+
print(f"Table {table.name}:")
73+
print(f" Meta: {table.meta}")
74+
print(f" This meta: {table.this.meta}")
75+
print()
76+
77+
except Exception as e:
78+
print(f"Error: {e}")
79+
80+
if __name__ == "__main__":
81+
test_cte_parsing()
82+
test_with_metadata()

0 commit comments

Comments
 (0)