Skip to content

Commit 9ae20ff

Browse files
committed
feat(vscode): go to definition for standalone audits
1 parent 7530c3c commit 9ae20ff

7 files changed

Lines changed: 171 additions & 42 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,6 @@ tests/_version.py
155155
# spark
156156
metastore_db/
157157
spark-warehouse/
158+
159+
# claude
160+
.claude/

sqlmesh/lsp/completions.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from sqlglot import Dialect, Tokenizer
33
from sqlmesh.lsp.custom import AllModelsResponse
44
import typing as t
5-
from sqlmesh.lsp.context import LSPContext
5+
from sqlmesh.lsp.context import LSPContext, ModelTarget
66

77

88
def get_sql_completions(context: t.Optional[LSPContext], file_uri: str) -> AllModelsResponse:
@@ -24,11 +24,20 @@ def get_models(context: t.Optional[LSPContext], file_uri: t.Optional[str]) -> t.
2424
"""
2525
if context is None:
2626
return set()
27-
all_models = set(model for models in context.map.values() for model in models)
28-
if file_uri is not None:
29-
models_file_refers_to = context.map[file_uri]
30-
for model in models_file_refers_to:
31-
all_models.discard(model)
27+
28+
all_models = set()
29+
# Extract model names from ModelInfo objects
30+
for file_info in context.map.values():
31+
if isinstance(file_info, ModelTarget):
32+
all_models.update(file_info.names)
33+
34+
# Remove models from the current file
35+
if file_uri is not None and file_uri in context.map:
36+
file_info = context.map[file_uri]
37+
if isinstance(file_info, ModelTarget):
38+
for model in file_info.names:
39+
all_models.discard(model)
40+
3241
return all_models
3342

3443

@@ -43,16 +52,25 @@ def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[str]) ->
4352
If both a context and a file_uri are provided, returns the keywords
4453
for the dialect of the model that the file belongs to.
4554
"""
46-
if file_uri is not None and context is not None:
47-
models = context.map[file_uri]
48-
if models:
49-
model = models[0]
50-
model_from_context = context.context.get_model(model)
51-
if model_from_context is not None:
52-
if model_from_context.dialect:
53-
return get_keywords_from_tokenizer(model_from_context.dialect)
55+
if file_uri is not None and context is not None and file_uri in context.map:
56+
file_info = context.map[file_uri]
57+
58+
# Handle ModelInfo objects
59+
if hasattr(file_info, "names") and file_info.names:
60+
model_name = file_info.names[0]
61+
model_from_context = context.context.get_model(model_name)
62+
if model_from_context is not None and model_from_context.dialect:
63+
return get_keywords_from_tokenizer(model_from_context.dialect)
64+
65+
# Handle AuditInfo objects
66+
elif hasattr(file_info, "name"):
67+
audit = context.context.standalone_audits.get(file_info.name)
68+
if audit is not None and audit.dialect:
69+
return get_keywords_from_tokenizer(audit.dialect)
70+
5471
if context is not None:
5572
return get_keywords_from_tokenizer(context.context.default_dialect)
73+
5674
return get_keywords_from_tokenizer(None)
5775

5876

sqlmesh/lsp/context.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,57 @@
1-
from collections import defaultdict
1+
from dataclasses import dataclass
22
from pathlib import Path
33
from sqlmesh.core.context import Context
44
import typing as t
55

66

7+
@dataclass
8+
class ModelTarget:
9+
"""Information about models in a file."""
10+
11+
names: t.List[str]
12+
13+
14+
@dataclass
15+
class AuditTarget:
16+
"""Information about standalone audits in a file."""
17+
18+
name: str
19+
20+
721
class LSPContext:
822
"""
9-
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
23+
A context that is used for linting. It contains the context and a reverse map of file uri to
24+
model names and standalone audit names.
1025
"""
1126

1227
def __init__(self, context: Context) -> None:
1328
self.context = context
14-
map: t.Dict[str, t.List[str]] = defaultdict(list)
29+
30+
# Add models to the map
31+
model_map: t.Dict[str, ModelTarget] = {}
1532
for model in context.models.values():
1633
if model._path is not None:
1734
path = Path(model._path).resolve()
18-
map[f"file://{path.as_posix()}"].append(model.name)
35+
uri = f"file://{path.as_posix()}"
36+
37+
if uri in model_map:
38+
model_map[uri].names.append(model.name)
39+
else:
40+
model_map[uri] = ModelTarget(names=[model.name])
41+
42+
# Add standalone audits to the map
43+
audit_map: t.Dict[str, AuditTarget] = {}
44+
for audit in context.standalone_audits.values():
45+
if audit._path is not None:
46+
path = Path(audit._path).resolve()
47+
uri = f"file://{path.as_posix()}"
48+
# Only add if not already in map (prioritize models if both exist in same file)
49+
if uri not in audit_map:
50+
audit_map[uri] = AuditTarget(name=audit.name)
1951

52+
# Maps file URIs to either ModelInfo or AuditInfo
53+
map: t.Dict[str, t.Union[ModelTarget, AuditTarget]] = {
54+
**model_map,
55+
**audit_map,
56+
}
2057
self.map = map

sqlmesh/lsp/main.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sqlmesh.core.context import Context
1313
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
1414
from sqlmesh.lsp.completions import get_sql_completions
15-
from sqlmesh.lsp.context import LSPContext
15+
from sqlmesh.lsp.context import LSPContext, ModelTarget
1616
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
1717
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
1818

@@ -62,8 +62,10 @@ def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> Non
6262
models = context.map[params.text_document.uri]
6363
if models is None:
6464
return
65+
if not isinstance(models, ModelTarget):
66+
return
6567
self.lint_cache[params.text_document.uri] = context.context.lint_models(
66-
models,
68+
models.names,
6769
raise_on_error=False,
6870
)
6971
ls.publish_diagnostics(
@@ -79,8 +81,10 @@ def did_change(ls: LanguageServer, params: types.DidChangeTextDocumentParams) ->
7981
models = context.map[params.text_document.uri]
8082
if models is None:
8183
return
84+
if not isinstance(models, ModelTarget):
85+
return
8286
self.lint_cache[params.text_document.uri] = context.context.lint_models(
83-
models,
87+
models.names,
8488
raise_on_error=False,
8589
)
8690
ls.publish_diagnostics(
@@ -96,8 +100,10 @@ def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> Non
96100
models = context.map[params.text_document.uri]
97101
if models is None:
98102
return
103+
if not isinstance(models, ModelTarget):
104+
return
99105
self.lint_cache[params.text_document.uri] = context.context.lint_models(
100-
models,
106+
models.names,
101107
raise_on_error=False,
102108
)
103109
ls.publish_diagnostics(

sqlmesh/lsp/reference.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from sqlmesh.core.dialect import normalize_model_name
55
from sqlmesh.core.model.definition import SqlModel
6-
from sqlmesh.lsp.context import LSPContext
6+
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
77
from sqlglot import exp
88

99
from sqlmesh.utils.pydantic import PydanticModel
@@ -20,7 +20,7 @@ def get_model_definitions_for_a_path(
2020
"""
2121
Get the model references for a given path.
2222
23-
Works for models and audits.
23+
Works for models and standalone audits.
2424
Works for targeting sql and python models.
2525
2626
Steps:
@@ -31,39 +31,63 @@ def get_model_definitions_for_a_path(
3131
- Try get_model before normalization
3232
- Match to models that the model refers to
3333
"""
34-
# Ensure the path is a sql model
34+
# Ensure the path is a sql file
3535
if not document_uri.endswith(".sql"):
3636
return []
3737

38-
# Get the model
39-
models = lint_context.map[document_uri]
40-
if not models:
38+
# Get the file info from the context map
39+
if document_uri not in lint_context.map:
4140
return []
42-
model = lint_context.context.get_model(model_or_snapshot=models[0], raise_if_missing=False)
43-
if model is None or not isinstance(model, SqlModel):
41+
42+
file_info = lint_context.map[document_uri]
43+
44+
# Process based on whether it's a model or standalone audit
45+
if isinstance(file_info, ModelTarget):
46+
# It's a model
47+
model = lint_context.context.get_model(
48+
model_or_snapshot=file_info.names[0], raise_if_missing=False
49+
)
50+
if model is None or not isinstance(model, SqlModel):
51+
return []
52+
53+
query = model.query
54+
dialect = model.dialect
55+
depends_on = model.depends_on
56+
file_path = model._path
57+
elif isinstance(file_info, AuditTarget):
58+
# It's a standalone audit
59+
audit = lint_context.context.standalone_audits.get(file_info.name)
60+
if audit is None:
61+
return []
62+
63+
query = audit.query
64+
dialect = audit.dialect
65+
depends_on = audit.depends_on
66+
file_path = audit._path
67+
else:
4468
return []
4569

4670
# Find all possible references
4771
references = []
48-
tables = list(model.query.find_all(exp.Table))
72+
73+
# Get SQL query and find all table references
74+
tables = list(query.find_all(exp.Table))
4975
if len(tables) == 0:
5076
return []
5177

52-
read_file = open(model._path, "r").readlines()
78+
read_file = open(file_path, "r").readlines()
5379

5480
for table in tables:
55-
depends_on = model.depends_on
56-
5781
# Normalize the table reference
5882
unaliased = table.copy()
5983
if unaliased.args.get("alias") is not None:
6084
unaliased.set("alias", None)
61-
reference_name = unaliased.sql(dialect=model.dialect)
85+
reference_name = unaliased.sql(dialect=dialect)
6286
try:
6387
normalized_reference_name = normalize_model_name(
6488
reference_name,
6589
default_catalog=lint_context.context.default_catalog,
66-
dialect=model.dialect,
90+
dialect=dialect,
6791
)
6892
if normalized_reference_name not in depends_on:
6993
continue

tests/lsp/test_context.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from sqlmesh.core.context import Context
3-
from sqlmesh.lsp.context import LSPContext
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
44

55

66
@pytest.mark.fast
@@ -16,4 +16,7 @@ def test_lsp_context():
1616
active_customers_key = next(
1717
key for key in lsp_context.map.keys() if key.endswith("models/active_customers.sql")
1818
)
19-
assert lsp_context.map[active_customers_key] == ["sushi.active_customers"]
19+
20+
# Check that the value is a ModelInfo with the expected model name
21+
assert isinstance(lsp_context.map[active_customers_key], ModelTarget)
22+
assert "sushi.active_customers" in lsp_context.map[active_customers_key].names

tests/lsp/test_reference.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from sqlmesh.core.context import Context
3-
from sqlmesh.lsp.context import LSPContext
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
44
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
55

66

@@ -9,11 +9,16 @@ def test_reference() -> None:
99
context = Context(paths=["examples/sushi"])
1010
lsp_context = LSPContext(context)
1111

12+
# Find model URIs
1213
active_customers_uri = next(
13-
uri for uri, models in lsp_context.map.items() if "sushi.active_customers" in models
14+
uri
15+
for uri, info in lsp_context.map.items()
16+
if isinstance(info, ModelTarget) and "sushi.active_customers" in info.names
1417
)
1518
sushi_customers_uri = next(
16-
uri for uri, models in lsp_context.map.items() if "sushi.customers" in models
19+
uri
20+
for uri, info in lsp_context.map.items()
21+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
1722
)
1823

1924
references = get_model_definitions_for_a_path(lsp_context, active_customers_uri)
@@ -35,7 +40,9 @@ def test_reference_with_alias() -> None:
3540
lsp_context = LSPContext(context)
3641

3742
waiter_revenue_by_day_uri = next(
38-
uri for uri, models in lsp_context.map.items() if "sushi.waiter_revenue_by_day" in models
43+
uri
44+
for uri, info in lsp_context.map.items()
45+
if isinstance(info, ModelTarget) and "sushi.waiter_revenue_by_day" in info.names
3946
)
4047

4148
references = get_model_definitions_for_a_path(lsp_context, waiter_revenue_by_day_uri)
@@ -52,6 +59,37 @@ def test_reference_with_alias() -> None:
5259
assert get_string_from_range(read_file, references[2].range) == "sushi.items"
5360

5461

62+
@pytest.mark.fast
63+
def test_standalone_audit_reference() -> None:
64+
context = Context(paths=["examples/sushi"])
65+
lsp_context = LSPContext(context)
66+
67+
# Find the standalone audit URI
68+
audit_uri = next(
69+
uri
70+
for uri, info in lsp_context.map.items()
71+
if isinstance(info, AuditTarget) and info.name == "assert_item_price_above_zero"
72+
)
73+
74+
# Find the items model URI
75+
items_uri = next(
76+
uri
77+
for uri, info in lsp_context.map.items()
78+
if isinstance(info, ModelTarget) and "sushi.items" in info.names
79+
)
80+
81+
references = get_model_definitions_for_a_path(lsp_context, audit_uri)
82+
83+
assert len(references) == 1
84+
assert references[0].uri == items_uri
85+
86+
# Check that the reference in the correct range is sushi.items
87+
path = audit_uri.removeprefix("file://")
88+
read_file = open(path, "r").readlines()
89+
referenced_text = get_string_from_range(read_file, references[0].range)
90+
assert referenced_text == "sushi.items"
91+
92+
5593
def get_string_from_range(file_lines, range_obj) -> str:
5694
start_line = range_obj.start.line
5795
end_line = range_obj.end.line

0 commit comments

Comments
 (0)