Skip to content

Commit a830e4c

Browse files
Fix(lsp): Extend support for table references with columns (#4763)
1 parent 05c793c commit a830e4c

7 files changed

Lines changed: 406 additions & 34 deletions

File tree

examples/sushi/models/customers.sql

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@ LEFT JOIN (
3737
@ADD_ONE(1) AS another_column,
3838
FROM current_marketing_outer
3939
)
40-
SELECT * FROM current_marketing
40+
SELECT current_marketing.* FROM current_marketing WHERE current_marketing.customer_id != 100
4141
) AS m
4242
ON o.customer_id = m.customer_id
4343
LEFT JOIN raw.demographics AS d
4444
ON o.customer_id = d.customer_id
45+
WHERE sushi.orders.customer_id > 0

sqlmesh/lsp/reference.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,17 @@ def get_model_definitions_for_a_path(
208208
target_range=target_range,
209209
)
210210
)
211+
212+
column_references = _process_column_references(
213+
scope=scope,
214+
reference_name=table.name,
215+
read_file=read_file,
216+
referenced_model_uri=document_uri,
217+
description="",
218+
reference_type="cte",
219+
cte_target_range=target_range,
220+
)
221+
references.extend(column_references)
211222
continue
212223

213224
# For non-CTE tables, process as before (external model references)
@@ -276,6 +287,19 @@ def get_model_definitions_for_a_path(
276287
target_range=yaml_target_range,
277288
)
278289
)
290+
291+
column_references = _process_column_references(
292+
scope=scope,
293+
reference_name=normalized_reference_name,
294+
read_file=read_file,
295+
referenced_model_uri=referenced_model_uri,
296+
description=description,
297+
yaml_target_range=yaml_target_range,
298+
reference_type="external_model",
299+
default_catalog=lint_context.context.default_catalog,
300+
dialect=dialect,
301+
)
302+
references.extend(column_references)
279303
else:
280304
references.append(
281305
LSPModelReference(
@@ -288,6 +312,18 @@ def get_model_definitions_for_a_path(
288312
)
289313
)
290314

315+
column_references = _process_column_references(
316+
scope=scope,
317+
reference_name=normalized_reference_name,
318+
read_file=read_file,
319+
referenced_model_uri=referenced_model_uri,
320+
description=description,
321+
reference_type="model",
322+
default_catalog=lint_context.context.default_catalog,
323+
dialect=dialect,
324+
)
325+
references.extend(column_references)
326+
291327
return references
292328

293329

@@ -735,6 +771,104 @@ def _position_within_range(position: Position, range: Range) -> bool:
735771
)
736772

737773

774+
def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range:
775+
"""
776+
Get the range for a column's table reference, handling both simple and qualified table names.
777+
778+
Args:
779+
column: The column expression
780+
read_file: The file content as list of lines
781+
782+
Returns:
783+
The Range covering the table reference in the column
784+
"""
785+
786+
table_parts = column.parts[:-1]
787+
788+
start_range = TokenPositionDetails.from_meta(table_parts[0].meta).to_range(read_file)
789+
end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file)
790+
791+
return Range(
792+
start=to_lsp_position(start_range.start),
793+
end=to_lsp_position(end_range.end),
794+
)
795+
796+
797+
def _process_column_references(
798+
scope: t.Any,
799+
reference_name: str,
800+
read_file: t.List[str],
801+
referenced_model_uri: URI,
802+
description: t.Optional[str] = None,
803+
yaml_target_range: t.Optional[Range] = None,
804+
reference_type: t.Literal["model", "external_model", "cte"] = "model",
805+
default_catalog: t.Optional[str] = None,
806+
dialect: t.Optional[str] = None,
807+
cte_target_range: t.Optional[Range] = None,
808+
) -> t.List[Reference]:
809+
"""
810+
Process column references for a given table and create appropriate reference objects.
811+
812+
Args:
813+
scope: The SQL scope to search for columns
814+
reference_name: The full reference name (may include database/catalog)
815+
read_file: The file content as list of lines
816+
referenced_model_uri: URI of the referenced model
817+
description: Markdown description for the reference
818+
yaml_target_range: Target range for external models (YAML files)
819+
reference_type: Type of reference - "model", "external_model", or "cte"
820+
default_catalog: Default catalog for normalization
821+
dialect: SQL dialect for normalization
822+
cte_target_range: Target range for CTE references
823+
824+
Returns:
825+
List of table references for column usages
826+
"""
827+
828+
references: t.List[Reference] = []
829+
for column in scope.find_all(exp.Column):
830+
if column.table:
831+
if reference_type == "cte":
832+
if column.table == reference_name:
833+
table_range = _get_column_table_range(column, read_file)
834+
references.append(
835+
LSPCteReference(
836+
uri=referenced_model_uri.value,
837+
range=table_range,
838+
target_range=cte_target_range,
839+
)
840+
)
841+
else:
842+
table_parts = [part.sql(dialect) for part in column.parts[:-1]]
843+
table_ref = ".".join(table_parts)
844+
normalized_reference_name = normalize_model_name(
845+
table_ref,
846+
default_catalog=default_catalog,
847+
dialect=dialect,
848+
)
849+
if normalized_reference_name == reference_name:
850+
table_range = _get_column_table_range(column, read_file)
851+
if reference_type == "external_model":
852+
references.append(
853+
LSPExternalModelReference(
854+
uri=referenced_model_uri.value,
855+
range=table_range,
856+
markdown_description=description,
857+
target_range=yaml_target_range,
858+
)
859+
)
860+
else:
861+
references.append(
862+
LSPModelReference(
863+
uri=referenced_model_uri.value,
864+
range=table_range,
865+
markdown_description=description,
866+
)
867+
)
868+
869+
return references
870+
871+
738872
def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
739873
"""
740874
Find the range of a specific model block in a YAML file.

tests/lsp/test_reference_cte_find_all.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@ def test_cte_find_all_references():
2121

2222
# Test finding all references of "current_marketing"
2323
ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
24-
assert len(ranges) == 2
24+
assert len(ranges) == 2 # regex finds 2 occurrences (definition and FROM clause)
2525

2626
# Click on the CTE definition
2727
position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4)
2828
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)
29-
30-
# Should find both the definition and the usage
31-
assert len(references) == 2
29+
# Should find the definition, FROM clause, and column prefix usages
30+
assert len(references) == 4 # definition + FROM + 2 column prefix uses
3231
assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references)
3332

3433
reference_ranges = [ref.range for ref in references]
@@ -46,7 +45,7 @@ def test_cte_find_all_references():
4645
references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position)
4746

4847
# Should find the same references
49-
assert len(references) == 2
48+
assert len(references) == 4 # definition + FROM + 2 column prefix uses
5049
assert all(ref.uri == URI.from_path(sushi_customers_path).value for ref in references)
5150

5251
reference_ranges = [ref.range for ref in references]

0 commit comments

Comments
 (0)