Skip to content

Commit 19df844

Browse files
Fix(lsp): Extend support for table references to when used with columns
1 parent c88daec commit 19df844

7 files changed

Lines changed: 275 additions & 34 deletions

File tree

examples/sushi/models/customers.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ LEFT JOIN (
3737
@ADD_ONE(1) AS another_column,
3838
FROM current_marketing_outer
3939
)
40-
SELECT * FROM current_marketing
40+
SELECT current_marketing.* FROM current_marketing WHERE current_marketing.customer_id != 100
4141
) AS m
4242
ON o.customer_id = m.customer_id
4343
LEFT JOIN raw.demographics AS d
4444
ON o.customer_id = d.customer_id
45+
WHERE sushi.orders.customer_id > 0

sqlmesh/lsp/reference.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,17 @@ def get_model_definitions_for_a_path(
208208
target_range=target_range,
209209
)
210210
)
211+
212+
column_references = _process_column_references(
213+
scope=scope,
214+
reference_name=table.name,
215+
read_file=read_file,
216+
referenced_model_uri=document_uri,
217+
description="",
218+
reference_type="cte",
219+
cte_target_range=target_range,
220+
)
221+
references.extend(column_references)
211222
continue
212223

213224
# For non-CTE tables, process as before (external model references)
@@ -276,6 +287,19 @@ def get_model_definitions_for_a_path(
276287
target_range=yaml_target_range,
277288
)
278289
)
290+
291+
column_references = _process_column_references(
292+
scope=scope,
293+
reference_name=normalized_reference_name,
294+
read_file=read_file,
295+
referenced_model_uri=referenced_model_uri,
296+
description=description,
297+
yaml_target_range=yaml_target_range,
298+
reference_type="external_model",
299+
default_catalog=lint_context.context.default_catalog,
300+
dialect=dialect,
301+
)
302+
references.extend(column_references)
279303
else:
280304
references.append(
281305
LSPModelReference(
@@ -288,6 +312,18 @@ def get_model_definitions_for_a_path(
288312
)
289313
)
290314

315+
column_references = _process_column_references(
316+
scope=scope,
317+
reference_name=normalized_reference_name,
318+
read_file=read_file,
319+
referenced_model_uri=referenced_model_uri,
320+
description=description,
321+
reference_type="model",
322+
default_catalog=lint_context.context.default_catalog,
323+
dialect=dialect,
324+
)
325+
references.extend(column_references)
326+
# breakpoint()
291327
return references
292328

293329

@@ -735,6 +771,104 @@ def _position_within_range(position: Position, range: Range) -> bool:
735771
)
736772

737773

774+
def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range:
775+
"""
776+
Get the range for a column's table reference, handling both simple and qualified table names.
777+
778+
Args:
779+
column: The column expression
780+
read_file: The file content as list of lines
781+
782+
Returns:
783+
The Range covering the table reference in the column
784+
"""
785+
786+
table_parts = column.parts[:-1] if len(column.parts) > 1 else [column.parts[0]]
787+
788+
start_range = TokenPositionDetails.from_meta(table_parts[0].meta).to_range(read_file)
789+
end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file)
790+
791+
return Range(
792+
start=to_lsp_position(start_range.start),
793+
end=to_lsp_position(end_range.end),
794+
)
795+
796+
797+
def _process_column_references(
798+
scope: t.Any,
799+
reference_name: str,
800+
read_file: t.List[str],
801+
referenced_model_uri: URI,
802+
description: t.Optional[str] = None,
803+
yaml_target_range: t.Optional[Range] = None,
804+
reference_type: t.Literal["model", "external_model", "cte"] = "model",
805+
default_catalog: t.Optional[str] = None,
806+
dialect: t.Optional[str] = None,
807+
cte_target_range: t.Optional[Range] = None,
808+
) -> t.List[Reference]:
809+
"""
810+
Process column references for a given table and create appropriate reference objects.
811+
812+
Args:
813+
scope: The SQL scope to search for columns
814+
reference_name: The full reference name (may include database/catalog)
815+
read_file: The file content as list of lines
816+
referenced_model_uri: URI of the referenced model
817+
description: Markdown description for the reference
818+
yaml_target_range: Target range for external models (YAML files)
819+
reference_type: Type of reference - "model", "external_model", or "cte"
820+
default_catalog: Default catalog for normalization
821+
dialect: SQL dialect for normalization
822+
cte_target_range: Target range for CTE references
823+
824+
Returns:
825+
List of table references for column usages
826+
"""
827+
828+
references: t.List[Reference] = []
829+
for column in scope.find_all(exp.Column):
830+
if column.table:
831+
if reference_type == "cte":
832+
if column.table == reference_name:
833+
table_range = _get_column_table_range(column, read_file)
834+
references.append(
835+
LSPCteReference(
836+
uri=referenced_model_uri.value,
837+
range=table_range,
838+
target_range=cte_target_range,
839+
)
840+
)
841+
else:
842+
table_parts = [part.name for part in column.parts[:-1]]
843+
table_ref = ".".join(table_parts)
844+
normalized_reference_name = normalize_model_name(
845+
table_ref,
846+
default_catalog=default_catalog,
847+
dialect=dialect,
848+
)
849+
if normalized_reference_name == reference_name:
850+
table_range = _get_column_table_range(column, read_file)
851+
if reference_type == "external_model":
852+
references.append(
853+
LSPExternalModelReference(
854+
uri=referenced_model_uri.value,
855+
range=table_range,
856+
markdown_description=description,
857+
target_range=yaml_target_range,
858+
)
859+
)
860+
else:
861+
references.append(
862+
LSPModelReference(
863+
uri=referenced_model_uri.value,
864+
range=table_range,
865+
markdown_description=description,
866+
)
867+
)
868+
869+
return references
870+
871+
738872
def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
739873
"""
740874
Find the range of a specific model block in a YAML file.

tests/lsp/test_reference_cte_find_all.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@ def test_cte_find_all_references():
2121

2222
# Test finding all references of "current_marketing"
2323
ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
24-
assert len(ranges) == 2
24+
assert len(ranges) == 2 # regex finds 2 occurrences (definition and FROM clause)
2525

2626
# Click on the CTE definition
2727
position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4)
2828
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)
29-
30-
# Should find both the definition and the usage
31-
assert len(references) == 2
29+
# Should find the definition, FROM clause, and column prefix usages
30+
assert len(references) == 4 # definition + FROM + 2 column prefix uses
3231
assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references)
3332

3433
reference_ranges = [ref.range for ref in references]
@@ -46,7 +45,7 @@ def test_cte_find_all_references():
4645
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)
4746

4847
# Should find the same references
49-
assert len(references) == 2
48+
assert len(references) == 4 # definition + FROM + 2 column prefix uses
5049
assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references)
5150

5251
reference_ranges = [ref.range for ref in references]
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from lsprotocol.types import Position
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import get_all_references, get_references, LSPModelReference
5+
from sqlmesh.lsp.uri import URI
6+
from tests.lsp.test_reference_cte import find_ranges_from_regex
7+
8+
9+
def test_model_reference_with_column_prefix():
10+
context = Context(paths=["examples/sushi"])
11+
lsp_context = LSPContext(context)
12+
13+
sushi_customers_path = next(
14+
path
15+
for path, info in lsp_context.map.items()
16+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
17+
)
18+
19+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
20+
read_file = file.readlines()
21+
22+
# Test finding references for "sushi.orders"
23+
ranges = find_ranges_from_regex(read_file, r"sushi\.orders")
24+
25+
# Click on the table reference in FROM clause (should be the second occurrence)
26+
from_clause_range = None
27+
for r in ranges:
28+
line_content = read_file[r.start.line].strip()
29+
if "FROM" in line_content:
30+
from_clause_range = r
31+
break
32+
33+
assert from_clause_range is not None, "Should find FROM clause with sushi.orders"
34+
35+
position = Position(
36+
line=from_clause_range.start.line, character=from_clause_range.start.character + 6
37+
)
38+
39+
model_refs = get_all_references(lsp_context, URI.from_path(sushi_customers_path), position)
40+
41+
assert len(model_refs) >= 7
42+
43+
# Verify that we have the FROM clause reference
44+
assert any(ref.range.start.line == from_clause_range.start.line for ref in model_refs), (
45+
"Should find FROM clause reference"
46+
)
47+
48+
49+
def test_column_prefix_references_are_found():
50+
context = Context(paths=["examples/sushi"])
51+
lsp_context = LSPContext(context)
52+
53+
sushi_customers_path = next(
54+
path
55+
for path, info in lsp_context.map.items()
56+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
57+
)
58+
59+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
60+
read_file = file.readlines()
61+
62+
# Find all occurrences of sushi.orders in the file
63+
ranges = find_ranges_from_regex(read_file, r"sushi\.orders")
64+
65+
# Should find exactly 2: FROM clause and WHERE clause with column prefix
66+
assert len(ranges) == 2, (
67+
f"Expected 2 occurrences of 'sushi.orders', found {len(ranges)}"
68+
)
69+
70+
# Verify we have the expected lines
71+
line_contents = [read_file[r.start.line].strip() for r in ranges]
72+
73+
# Should find FROM clause
74+
assert any("FROM sushi.orders" in content for content in line_contents), (
75+
"Should find FROM clause with sushi.orders"
76+
)
77+
78+
# Should find customer_id in WHERE clause with column prefix
79+
assert any("WHERE sushi.orders.customer_id" in content for content in line_contents), (
80+
"Should find WHERE clause with sushi.orders.customer_id"
81+
)

tests/lsp/test_reference_model_find_all.py

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def test_find_references_for_model_usages():
3030
# Click on the model reference
3131
position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 6)
3232
references = get_model_find_all_references(lsp_context, URI.from_path(customers_path), position)
33-
assert len(references) >= 6, (
34-
f"Expected at least 6 references to sushi.orders, found {len(references)}"
33+
assert len(references) >= 7, (
34+
f"Expected at least 7 references to sushi.orders (including column prefix), found {len(references)}"
3535
)
3636

3737
# Verify expected files are present
@@ -50,44 +50,62 @@ def test_find_references_for_model_usages():
5050
)
5151

5252
# Verify exact ranges for each reference pattern
53+
# Note: customers file has multiple references due to column prefix support
5354
expected_ranges = {
54-
"orders": (0, 0, 0, 0), # the start for the model itself
55-
"customers": (30, 7, 30, 19),
56-
"waiter_revenue_by_day": (19, 5, 19, 17),
57-
"customer_revenue_lifetime": (38, 7, 38, 19),
58-
"customer_revenue_by_day": (33, 5, 33, 17),
59-
"latest_order": (12, 5, 12, 17),
55+
"orders": [(0, 0, 0, 0)], # the start for the model itself
56+
"customers": [(30, 7, 30, 19), (44, 6, 44, 18)], # FROM clause and WHERE clause
57+
"waiter_revenue_by_day": [(19, 5, 19, 17)],
58+
"customer_revenue_lifetime": [(38, 7, 38, 19)],
59+
"customer_revenue_by_day": [(33, 5, 33, 17)],
60+
"latest_order": [(12, 5, 12, 17)],
6061
}
6162

63+
# Group references by file pattern
64+
refs_by_pattern = {}
6265
for ref in references:
6366
matched_pattern = None
6467
for pattern in expected_patterns:
6568
if pattern in ref.uri:
6669
matched_pattern = pattern
6770
break
6871

69-
assert matched_pattern is not None, (
70-
f"Reference URI {ref.uri} doesn't match any expected pattern"
71-
)
72+
if matched_pattern:
73+
if matched_pattern not in refs_by_pattern:
74+
refs_by_pattern[matched_pattern] = []
75+
refs_by_pattern[matched_pattern].append(ref)
7276

73-
# Get expected range for this model
74-
expected_start_line, expected_start_char, expected_end_line, expected_end_char = (
75-
expected_ranges[matched_pattern]
76-
)
77+
# Verify each pattern has the expected references
78+
for pattern, expected_range_list in expected_ranges.items():
79+
assert pattern in refs_by_pattern, f"Missing references for pattern '{pattern}'"
7780

78-
# Assert exact range match
79-
assert ref.range.start.line == expected_start_line, (
80-
f"Expected {matched_pattern} reference start line {expected_start_line}, found {ref.range.start.line}"
81-
)
82-
assert ref.range.start.character == expected_start_char, (
83-
f"Expected {matched_pattern} reference start character {expected_start_char}, found {ref.range.start.character}"
81+
actual_refs = refs_by_pattern[pattern]
82+
assert len(actual_refs) == len(expected_range_list), (
83+
f"Expected {len(expected_range_list)} references for {pattern}, found {len(actual_refs)}"
8484
)
85-
assert ref.range.end.line == expected_end_line, (
86-
f"Expected {matched_pattern} reference end line {expected_end_line}, found {ref.range.end.line}"
87-
)
88-
assert ref.range.end.character == expected_end_char, (
89-
f"Expected {matched_pattern} reference end character {expected_end_char}, found {ref.range.end.character}"
85+
86+
# Sort both actual and expected by line number for consistent comparison
87+
actual_refs_sorted = sorted(
88+
actual_refs, key=lambda r: (r.range.start.line, r.range.start.character)
9089
)
90+
expected_sorted = sorted(expected_range_list, key=lambda r: (r[0], r[1]))
91+
92+
for i, (ref, expected_range) in enumerate(zip(actual_refs_sorted, expected_sorted)):
93+
expected_start_line, expected_start_char, expected_end_line, expected_end_char = (
94+
expected_range
95+
)
96+
97+
assert ref.range.start.line == expected_start_line, (
98+
f"Expected {pattern} reference #{i + 1} start line {expected_start_line}, found {ref.range.start.line}"
99+
)
100+
assert ref.range.start.character == expected_start_char, (
101+
f"Expected {pattern} reference #{i + 1} start character {expected_start_char}, found {ref.range.start.character}"
102+
)
103+
assert ref.range.end.line == expected_end_line, (
104+
f"Expected {pattern} reference #{i + 1} end line {expected_end_line}, found {ref.range.end.line}"
105+
)
106+
assert ref.range.end.character == expected_end_char, (
107+
f"Expected {pattern} reference #{i + 1} end character {expected_end_char}, found {ref.range.end.character}"
108+
)
91109

92110

93111
def test_find_references_for_marketing_model():

tests/test_forking.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,14 @@ def test_parallel_load(assert_exp_eq, mocker):
5555
"current_marketing"."status" AS "status",
5656
"current_marketing"."another_column" AS "another_column"
5757
FROM "current_marketing" AS "current_marketing"
58+
WHERE
59+
"current_marketing"."customer_id" <> 100
5860
) AS "m"
5961
ON "m"."customer_id" = "o"."customer_id"
6062
LEFT JOIN "memory"."raw"."demographics" AS "d"
6163
ON "d"."customer_id" = "o"."customer_id"
64+
WHERE
65+
"o"."customer_id" > 0
6266
""",
6367
)
6468

0 commit comments

Comments
 (0)