Skip to content

Commit 05ddba1

Browse files
committed
temp
1 parent fcc56ae commit 05ddba1

5 files changed

Lines changed: 58 additions & 136 deletions

File tree

sqlmesh/lsp/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def goto_definition(
280280
start=types.Position(line=0, character=0),
281281
end=types.Position(line=0, character=0),
282282
)
283-
283+
284284
location_links.append(
285285
types.LocationLink(
286286
target_uri=reference.uri,

sqlmesh/lsp/reference.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def get_model_definitions_for_a_path(
130130

131131
# Get SQL query and find all table references
132132
tables = list(query.find_all(exp.Table))
133-
133+
134134
with open(file_path, "r", encoding="utf-8") as file:
135135
read_file = file.readlines()
136136

@@ -139,33 +139,33 @@ def get_model_definitions_for_a_path(
139139
with_clause = query.find(exp.With)
140140
if with_clause:
141141
for cte in with_clause.expressions:
142-
if cte.alias and hasattr(cte, 'meta') and cte.meta:
142+
if cte.alias and hasattr(cte, "meta") and cte.meta:
143143
cte_definitions[cte.alias] = cte
144144

145145
for table in tables:
146146
table_name = table.name
147-
147+
148148
# Check if this table reference is a CTE
149149
if table_name in cte_definitions:
150150
# This is a CTE reference - create a reference to the CTE definition
151151
cte_def = cte_definitions[table_name]
152-
if hasattr(cte_def, 'meta') and cte_def.meta:
152+
if hasattr(cte_def, "meta") and cte_def.meta:
153153
try:
154154
# Get the position of the table reference
155155
table_meta = TokenPositionDetails.from_meta(table.this.meta)
156156
table_range = _range_from_token_position_details(table_meta, read_file)
157-
157+
158158
# Get the position of the CTE definition (alias part)
159159
cte_meta = TokenPositionDetails.from_meta(cte_def.meta)
160160
cte_range = _range_from_token_position_details(cte_meta, read_file)
161-
161+
162162
# Create a reference from the table usage to the CTE definition
163163
references.append(
164164
Reference(
165165
uri=document_uri.value, # Same file
166166
range=table_range,
167167
description=f"CTE: {table_name}",
168-
target_range=cte_range
168+
target_range=cte_range,
169169
)
170170
)
171171
except Exception:

tests/lsp/simple_cte_test.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

tests/lsp/test_cte_references.py

Lines changed: 0 additions & 82 deletions
This file was deleted.

tests/lsp/test_reference_cte.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#!/usr/bin/env python3
2+
"""Test script for CTE go-to-definition functionality."""
3+
4+
import re
5+
from sqlmesh.core.context import Context
6+
from sqlmesh.lsp.context import LSPContext, ModelTarget
7+
from sqlmesh.lsp.reference import get_references
8+
from sqlmesh.lsp.uri import URI
9+
from lsprotocol.types import Range, Position
10+
11+
def test_cte_parsing():
12+
context = Context(paths=["examples/sushi"])
13+
lsp_context = LSPContext(context)
14+
15+
# Find model URIs
16+
sushi_customers_path = next(
17+
path
18+
for path, info in lsp_context.map.items()
19+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
20+
)
21+
22+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
23+
read_file = file.readlines()
24+
25+
references = get_references(lsp_context, URI.from_path(sushi_customers_path), Position(line=0, character=0))
26+
27+
assert len(references) == 1
28+
assert references[0].uri == URI.from_path(sushi_customers_path)
29+
30+
ranges = find_ranges_from_regex(read_file, r"WITH\s+current_marketing\s+AS\s+\(SELECT\s+customer_id,\s+status\s+FROM\s+sushi\.marketing\s+WHERE\s+valid_to\s+is\s+null\)")
31+
assert len(ranges) == 1
32+
assert ranges[0].start == Position(line=1, character=0)
33+
assert ranges[0].end == Position(line=1, character=1)
34+
35+
36+
37+
38+
39+
def find_ranges_from_regex(read_file: t.List[str], regex: str) -> t.List[Range]:
40+
"""Find all ranges in the read file that match the regex."""
41+
return [
42+
Range(
43+
start=Position(line=line_number, character=match.start()),
44+
end=Position(line=line_number, character=match.end())
45+
)
46+
for line_number, line in enumerate(read_file)
47+
for match in [m for m in [re.search(regex, line)] if m]
48+
]
49+
50+

0 commit comments

Comments
 (0)