Skip to content

Commit d5c8add

Browse files
committed
feat(lsp): add go to definition for ctes
1 parent 01114b5 commit d5c8add

4 files changed

Lines changed: 213 additions & 70 deletions

File tree

examples/sushi/models/customers.sql

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ CREATE VIEW raw.demographics AS (
1717
SELECT 1 AS customer_id, '00000' AS zip
1818
);
1919

20-
WITH current_marketing AS (
20+
WITH current_marketing_outer AS (
2121
SELECT
2222
customer_id,
2323
status
@@ -29,7 +29,15 @@ SELECT DISTINCT
2929
m.status,
3030
d.zip
3131
FROM sushi.orders AS o
32-
LEFT JOIN current_marketing AS m
32+
LEFT JOIN (
33+
WITH current_marketing AS (
34+
SELECT
35+
customer_id,
36+
status
37+
FROM current_marketing_outer
38+
)
39+
SELECT * FROM current_marketing
40+
) AS m
3341
ON o.customer_id = m.customer_id
3442
LEFT JOIN raw.demographics AS d
3543
ON o.customer_id = d.customer_id

sqlmesh/lsp/main.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,21 +265,31 @@ def goto_definition(
265265
raise RuntimeError(f"No context found for document: {document.path}")
266266

267267
references = get_references(self.lsp_context, uri, params.position)
268-
return [
269-
types.LocationLink(
270-
target_uri=reference.uri,
271-
target_selection_range=types.Range(
268+
location_links = []
269+
for reference in references:
270+
# Use target_range if available (for CTEs), otherwise default to start of file
271+
if reference.target_range:
272+
target_range = reference.target_range
273+
target_selection_range = reference.target_range
274+
else:
275+
target_range = types.Range(
272276
start=types.Position(line=0, character=0),
273277
end=types.Position(line=0, character=0),
274-
),
275-
target_range=types.Range(
278+
)
279+
target_selection_range = types.Range(
276280
start=types.Position(line=0, character=0),
277281
end=types.Position(line=0, character=0),
278-
),
279-
origin_selection_range=reference.range,
282+
)
283+
284+
location_links.append(
285+
types.LocationLink(
286+
target_uri=reference.uri,
287+
target_selection_range=target_selection_range,
288+
target_range=target_range,
289+
origin_selection_range=reference.range,
290+
)
280291
)
281-
for reference in references
282-
]
292+
return location_links
283293
except Exception as e:
284294
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
285295
return []

sqlmesh/lsp/reference.py

Lines changed: 119 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,26 @@
55
from sqlmesh.core.model.definition import SqlModel
66
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
77
from sqlglot import exp
8+
from sqlglot.optimizer.scope import build_scope
89
from sqlmesh.lsp.uri import URI
910
from sqlmesh.utils.pydantic import PydanticModel
1011

1112

1213
class Reference(PydanticModel):
1314
"""
14-
A reference to a model.
15+
A reference to a model or CTE.
1516
1617
Attributes:
1718
range: The range of the reference in the source file
18-
uri: The uri of the referenced model
19-
description: The description of the referenced model
19+
uri: The uri of the referenced model or file
20+
description: The description of the referenced model or CTE
21+
target_range: The range of the definition for go-to-definition (optional, used for CTEs)
2022
"""
2123

2224
range: Range
2325
uri: str
2426
description: t.Optional[str] = None
27+
target_range: t.Optional[Range] = None
2528

2629

2730
def by_position(position: Position) -> t.Callable[[Reference], bool]:
@@ -87,6 +90,7 @@ def get_model_definitions_for_a_path(
8790
- Need to normalize it before matching
8891
- Try get_model before normalization
8992
- Match to models that the model refers to
93+
- Also find CTE references within the query
9094
"""
9195
path = document_uri.to_path()
9296
if path.suffix != ".sql":
@@ -125,64 +129,121 @@ def get_model_definitions_for_a_path(
125129
# Find all possible references
126130
references = []
127131

128-
# Get SQL query and find all table references
129-
tables = list(query.find_all(exp.Table))
130-
if len(tables) == 0:
131-
return []
132-
133132
with open(file_path, "r", encoding="utf-8") as file:
134133
read_file = file.readlines()
135134

136-
for table in tables:
137-
# Normalize the table reference
138-
unaliased = table.copy()
139-
if unaliased.args.get("alias") is not None:
140-
unaliased.set("alias", None)
141-
reference_name = unaliased.sql(dialect=dialect)
142-
try:
143-
normalized_reference_name = normalize_model_name(
144-
reference_name,
145-
default_catalog=lint_context.context.default_catalog,
146-
dialect=dialect,
147-
)
148-
if normalized_reference_name not in depends_on:
149-
continue
150-
except Exception:
151-
# Skip references that cannot be normalized
152-
continue
153-
154-
# Get the referenced model uri
155-
referenced_model = lint_context.context.get_model(
156-
model_or_snapshot=normalized_reference_name, raise_if_missing=False
157-
)
158-
if referenced_model is None:
159-
continue
160-
referenced_model_path = referenced_model._path
161-
# Check whether the path exists
162-
if not referenced_model_path.is_file():
163-
continue
164-
referenced_model_uri = URI.from_path(referenced_model_path)
165-
166-
# Extract metadata for positioning
167-
table_meta = TokenPositionDetails.from_meta(table.this.meta)
168-
table_range = _range_from_token_position_details(table_meta, read_file)
169-
start_pos = table_range.start
170-
end_pos = table_range.end
171-
172-
# If there's a catalog or database qualifier, adjust the start position
173-
catalog_or_db = table.args.get("catalog") or table.args.get("db")
174-
if catalog_or_db is not None:
175-
catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta)
176-
catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file)
177-
start_pos = catalog_or_db_range.start
178-
179-
references.append(
180-
Reference(
181-
uri=referenced_model_uri.value,
182-
range=Range(start=start_pos, end=end_pos),
183-
description=referenced_model.description,
184-
)
185-
)
135+
# Build scope tree to properly handle nested CTEs
136+
root_scope = build_scope(query)
137+
138+
if root_scope:
139+
# Traverse all scopes to find CTE definitions and table references
140+
for scope in root_scope.traverse():
141+
# Build a map of CTE names to their definitions within this scope
142+
cte_definitions = {}
143+
144+
# For CTEs defined in this scope
145+
for cte in scope.ctes:
146+
if cte.alias:
147+
cte_definitions[cte.alias] = cte
148+
149+
# Also include CTEs from parent scopes (for references inside nested CTEs)
150+
parent = scope.parent
151+
while parent:
152+
for cte in parent.ctes:
153+
if cte.alias and cte.alias not in cte_definitions:
154+
cte_definitions[cte.alias] = cte
155+
parent = parent.parent
156+
157+
# Get all table references in this scope
158+
tables = list(scope.find_all(exp.Table))
159+
160+
for table in tables:
161+
table_name = table.name
162+
163+
# Check if this table reference is a CTE in the current scope
164+
if table_name in cte_definitions:
165+
try:
166+
# This is a CTE reference - create a reference to the CTE definition
167+
cte_def = cte_definitions[table_name]
168+
args = cte_def.args["alias"]
169+
if args and isinstance(args, exp.TableAlias):
170+
identifier = args.this
171+
if isinstance(identifier, exp.Identifier):
172+
meta = identifier.meta
173+
174+
table_meta_obj = TokenPositionDetails.from_meta(meta)
175+
target_range = _range_from_token_position_details(
176+
table_meta_obj, read_file
177+
)
178+
179+
table_meta_obj = TokenPositionDetails.from_meta(table.this.meta)
180+
table_range = _range_from_token_position_details(
181+
table_meta_obj, read_file
182+
)
183+
184+
references.append(
185+
Reference(
186+
uri=document_uri.value, # Same file
187+
range=table_range,
188+
target_range=target_range,
189+
)
190+
)
191+
except Exception:
192+
pass
193+
continue
194+
195+
# For non-CTE tables, process as before (external model references)
196+
# Normalize the table reference
197+
unaliased = table.copy()
198+
if unaliased.args.get("alias") is not None:
199+
unaliased.set("alias", None)
200+
reference_name = unaliased.sql(dialect=dialect)
201+
try:
202+
normalized_reference_name = normalize_model_name(
203+
reference_name,
204+
default_catalog=lint_context.context.default_catalog,
205+
dialect=dialect,
206+
)
207+
if normalized_reference_name not in depends_on:
208+
continue
209+
except Exception:
210+
# Skip references that cannot be normalized
211+
continue
212+
213+
# Get the referenced model uri
214+
referenced_model = lint_context.context.get_model(
215+
model_or_snapshot=normalized_reference_name, raise_if_missing=False
216+
)
217+
if referenced_model is None:
218+
continue
219+
referenced_model_path = referenced_model._path
220+
# Check whether the path exists
221+
if not referenced_model_path.is_file():
222+
continue
223+
referenced_model_uri = URI.from_path(referenced_model_path)
224+
225+
# Extract metadata for positioning
226+
table_meta = TokenPositionDetails.from_meta(table.this.meta)
227+
table_range = _range_from_token_position_details(table_meta, read_file)
228+
start_pos = table_range.start
229+
end_pos = table_range.end
230+
231+
# If there's a catalog or database qualifier, adjust the start position
232+
catalog_or_db = table.args.get("catalog") or table.args.get("db")
233+
if catalog_or_db is not None:
234+
catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta)
235+
catalog_or_db_range = _range_from_token_position_details(
236+
catalog_or_db_meta, read_file
237+
)
238+
start_pos = catalog_or_db_range.start
239+
240+
references.append(
241+
Reference(
242+
uri=referenced_model_uri.value,
243+
range=Range(start=start_pos, end=end_pos),
244+
description=referenced_model.description,
245+
)
246+
)
186247

187248
return references
188249

tests/lsp/test_reference_cte.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import re
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import get_references
5+
from sqlmesh.lsp.uri import URI
6+
from lsprotocol.types import Range, Position
7+
import typing as t
8+
9+
10+
def test_cte_parsing():
11+
context = Context(paths=["examples/sushi"])
12+
lsp_context = LSPContext(context)
13+
14+
# Find model URIs
15+
sushi_customers_path = next(
16+
path
17+
for path, info in lsp_context.map.items()
18+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
19+
)
20+
21+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
22+
read_file = file.readlines()
23+
24+
# Find position of the cte reference
25+
ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
26+
assert len(ranges) == 2
27+
position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
28+
references = get_references(lsp_context, URI.from_path(sushi_customers_path), position)
29+
assert len(references) == 1
30+
assert references[0].uri == URI.from_path(sushi_customers_path).value
31+
assert references[0].description is None
32+
assert (
33+
references[0].range.start.line == ranges[1].start.line
34+
) # The reference location (where we clicked)
35+
assert (
36+
references[0].target_range.start.line == ranges[0].start.line
37+
) # The CTE definition location
38+
39+
# Find the position of the current_marketing_outer reference
40+
ranges = find_ranges_from_regex(read_file, r"current_marketing_outer")
41+
assert len(ranges) == 2
42+
position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
43+
references = get_references(lsp_context, URI.from_path(sushi_customers_path), position)
44+
assert len(references) == 1
45+
assert references[0].uri == URI.from_path(sushi_customers_path).value
46+
assert references[0].description is None
47+
assert (
48+
references[0].range.start.line == ranges[1].start.line
49+
) # The reference location (where we clicked)
50+
assert (
51+
references[0].target_range.start.line == ranges[0].start.line
52+
) # The CTE definition location
53+
54+
55+
def find_ranges_from_regex(read_file: t.List[str], regex: str) -> t.List[Range]:
    """Return the range of every match of ``regex`` in the given lines.

    Each line is searched on its own, so a match never spans more than one
    line. Results are ordered by line, then left to right within a line.

    Args:
        read_file: The lines of the file, e.g. as returned by ``readlines()``.
        regex: The regular expression pattern to search for.

    Returns:
        One ``Range`` per non-overlapping match. ``line`` is the zero-based
        index of the line in ``read_file``; ``character`` values are the
        match's start and end offsets within that line.
    """
    pattern = re.compile(regex)
    return [
        Range(
            start=Position(line=line_number, character=match.start()),
            end=Position(line=line_number, character=match.end()),
        )
        for line_number, line in enumerate(read_file)
        # finditer rather than search, so several matches on one line are all reported.
        for match in pattern.finditer(line)
    ]

0 commit comments

Comments
 (0)