Skip to content

Commit c2258fc

Browse files
committed
feat: add go to definition to lsp
1 parent 3271ae1 commit c2258fc

6 files changed

Lines changed: 298 additions & 17 deletions

File tree

sqlmesh/lsp/context.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from collections import defaultdict
2+
from pathlib import Path
3+
from sqlmesh.core.context import Context
4+
import typing as t
5+
6+
7+
class LSPContext:
8+
"""
9+
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
10+
"""
11+
12+
def __init__(self, context: Context) -> None:
13+
self.context = context
14+
map: t.Dict[str, t.List[str]] = defaultdict(list)
15+
for model in context.models.values():
16+
if model._path is not None:
17+
path = Path(model._path).resolve()
18+
map[f"file://{path.as_posix()}"].append(model.name)
19+
self.map = map

sqlmesh/lsp/main.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#!/usr/bin/env python
22
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals."""
33

4-
from collections import defaultdict
54
import logging
65
import typing as t
76
from pathlib import Path
@@ -12,21 +11,8 @@
1211
from sqlmesh._version import __version__
1312
from sqlmesh.core.context import Context
1413
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
15-
16-
17-
class LSPContext:
18-
"""
19-
A context that is used for linting. It contains the context and a reverse map of file uri to model names .
20-
"""
21-
22-
def __init__(self, context: Context) -> None:
23-
self.context = context
24-
map: t.Dict[str, t.List[str]] = defaultdict(list)
25-
for model in context.models.values():
26-
if model._path is not None:
27-
path = Path(model._path).resolve()
28-
map[f"file://{path.as_posix()}"].append(model.name)
29-
self.map = map
14+
from sqlmesh.lsp.context import LSPContext
15+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
3016

3117

3218
class SQLMeshLanguageServer:
@@ -144,6 +130,43 @@ def formatting(
144130
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
145131
return []
146132

133+
@self.server.feature(types.TEXT_DOCUMENT_DEFINITION)
134+
def goto_definition(
135+
ls: LanguageServer, params: types.DefinitionParams
136+
) -> t.List[types.LocationLink]:
137+
"""Jump to an object's definition."""
138+
try:
139+
self._ensure_context_for_document(params.text_document.uri)
140+
document = ls.workspace.get_document(params.text_document.uri)
141+
if self.lsp_context is None:
142+
raise RuntimeError(f"No context found for document: {document.path}")
143+
144+
references = get_model_definitions_for_a_path(
145+
self.lsp_context, params.text_document.uri
146+
)
147+
if len(references) == 0:
148+
return []
149+
150+
return [
151+
types.LocationLink(
152+
target_uri=reference.uri,
153+
target_selection_range=types.Range(
154+
start=types.Position(line=0, character=0),
155+
end=types.Position(line=0, character=0),
156+
),
157+
target_range=types.Range(
158+
start=types.Position(line=0, character=0),
159+
end=types.Position(line=0, character=0),
160+
),
161+
origin_selection_range=reference.range,
162+
)
163+
for reference in references
164+
]
165+
166+
except Exception as e:
167+
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
168+
return []
169+
147170
def _context_get_or_load(self, document_uri: str) -> LSPContext:
148171
if self.lsp_context is None:
149172
self._ensure_context_for_document(document_uri)

sqlmesh/lsp/reference.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
from pathlib import Path
2+
3+
from lsprotocol.types import Range, Position
4+
import typing as t
5+
6+
from sqlmesh.core.dialect import normalize_model_name
7+
from sqlmesh.core.model.definition import SqlModel
8+
from sqlmesh.lsp.context import LSPContext
9+
from sqlglot import exp
10+
11+
from sqlmesh.utils.pydantic import PydanticModel
12+
13+
14+
class Reference(PydanticModel):
15+
range: Range
16+
uri: str
17+
18+
19+
def get_model_definitions_for_a_path(
20+
lint_context: LSPContext, document_uri: str
21+
) -> t.List[Reference]:
22+
"""
23+
Get the model references for a given path.
24+
25+
Works for models and audits.
26+
Works for targeting sql and python models.
27+
28+
Steps:
29+
- Get the parsed query
30+
- Find all table objects using find_all exp.Table
31+
- Match the string against all model names
32+
- Need to normalize it before matching
33+
- Try get_model before normalization
34+
- Match to models that the model refers to
35+
"""
36+
# Ensure the path is a sql model
37+
if not document_uri.endswith(".sql"):
38+
return []
39+
40+
# Get the model
41+
models = lint_context.map[document_uri]
42+
if models is None:
43+
return []
44+
if len(models) == 0:
45+
return []
46+
model_name = models[0]
47+
model = lint_context.context.get_model(model_or_snapshot=model_name, raise_if_missing=False)
48+
if model is None:
49+
return []
50+
if not isinstance(model, SqlModel):
51+
return []
52+
53+
# Find all possible references
54+
tables = list(model.query.find_all(exp.Table))
55+
if len(tables) == 0:
56+
return []
57+
58+
references = []
59+
for table in tables:
60+
depends_on = model.depends_on
61+
62+
# Normalize the table reference
63+
reference_name = table.sql(dialect=model.dialect)
64+
normalized_reference_name = normalize_model_name(
65+
reference_name,
66+
default_catalog=lint_context.context.default_catalog,
67+
dialect=model.dialect,
68+
)
69+
if normalized_reference_name not in depends_on:
70+
continue
71+
72+
# Get the referenced model uri
73+
referenced_model = lint_context.context.get_model(
74+
model_or_snapshot=normalized_reference_name, raise_if_missing=False
75+
)
76+
if referenced_model is None:
77+
continue
78+
# Get the model uri
79+
referenced_model_path = referenced_model._path
80+
if referenced_model_path is None:
81+
continue
82+
# Fully qualify the path in case
83+
path = Path.resolve(Path(referenced_model_path))
84+
referenced_model_uri = f"file://{path}"
85+
read_file = open(path, "r").readlines()
86+
87+
# Extract metadata for positioning
88+
table_meta = TokenPositionDetails.from_meta(table.this.meta)
89+
table_range = _range_from_token_position_details(table_meta, read_file)
90+
start_pos = table_range.start
91+
end_pos = table_range.end
92+
93+
# If there's a database qualifier, adjust the start position
94+
db = table.args.get("db")
95+
if db is not None:
96+
db_meta = TokenPositionDetails.from_meta(db.meta)
97+
db_range = _range_from_token_position_details(db_meta, read_file)
98+
start_pos = db_range.start
99+
100+
# If there's a catalog qualifier, adjust the start position further
101+
catalog = table.args.get("catalog")
102+
if catalog is not None:
103+
catalog_meta = TokenPositionDetails.from_meta(catalog.meta)
104+
catalog_range = _range_from_token_position_details(catalog_meta, read_file)
105+
start_pos = catalog_range.start
106+
107+
references.append(
108+
Reference(uri=referenced_model_uri, range=Range(start=start_pos, end=end_pos))
109+
)
110+
111+
return references
112+
113+
114+
class TokenPositionDetails(PydanticModel):
115+
"""
116+
Details about a token's position in the source code.
117+
118+
Attributes:
119+
line (int): The line that the token ends on.
120+
col (int): The column that the token ends on.
121+
start (int): The start index of the token.
122+
end (int): The ending index of the token.
123+
"""
124+
125+
line: int
126+
col: int
127+
start: int
128+
end: int
129+
130+
@staticmethod
131+
def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
132+
return TokenPositionDetails(
133+
line=meta["line"],
134+
col=meta["col"],
135+
start=meta["start"],
136+
end=meta["end"],
137+
)
138+
139+
140+
def _range_from_token_position_details(
141+
token_position_details: TokenPositionDetails, read_file: t.List[str]
142+
) -> Range:
143+
"""
144+
Convert a TokenPositionDetails object to a Range object.
145+
146+
:param token_position_details: Details about a token's position
147+
:param read_file: List of lines from the file
148+
:return: A Range object representing the token's position
149+
"""
150+
# Convert from 1-indexed to 0-indexed for line and column
151+
end_line_0 = token_position_details.line - 1
152+
end_col_0 = token_position_details.col
153+
154+
# Find the start line and column by counting backwards from the end position
155+
start_pos = token_position_details.start
156+
end_pos = token_position_details.end
157+
158+
# Initialize with the end position
159+
start_line_0 = end_line_0
160+
start_col_0 = end_col_0 - (end_pos - start_pos + 1)
161+
162+
# If start_col_0 is negative, we need to go back to previous lines
163+
while start_col_0 < 0 and start_line_0 > 0:
164+
start_line_0 -= 1
165+
start_col_0 += len(read_file[start_line_0])
166+
# Account for newline character
167+
if start_col_0 >= 0:
168+
break
169+
start_col_0 += 1 # For the newline character
170+
171+
# Ensure we don't have negative values
172+
start_col_0 = max(0, start_col_0)
173+
return Range(
174+
start=Position(line=start_line_0, character=start_col_0),
175+
end=Position(line=end_line_0, character=end_col_0),
176+
)

tests/lsp/test_context.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from sqlmesh.core.context import Context
2+
from sqlmesh.lsp.context import LSPContext
3+
4+
5+
def test_lsp_context():
6+
context = Context(paths=["examples/sushi"])
7+
lsp_context = LSPContext(context)
8+
9+
assert lsp_context is not None
10+
assert lsp_context.context is not None
11+
assert lsp_context.map is not None
12+
13+
# find one model in the map
14+
active_customers_key = next(
15+
key for key in lsp_context.map.keys() if key.endswith("models/active_customers.sql")
16+
)
17+
assert lsp_context.map[active_customers_key] == ["sushi.active_customers"]

tests/lsp/test_reference.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from sqlmesh.core.context import Context
2+
from sqlmesh.lsp.context import LSPContext
3+
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
4+
5+
6+
def test_reference() -> None:
7+
context = Context(paths=["examples/sushi"])
8+
lsp_context = LSPContext(context)
9+
10+
active_customers_uri = next(
11+
uri for uri, models in lsp_context.map.items() if "sushi.active_customers" in models
12+
)
13+
sushi_customers_uri = next(
14+
uri for uri, models in lsp_context.map.items() if "sushi.customers" in models
15+
)
16+
17+
references = get_model_definitions_for_a_path(lsp_context, active_customers_uri)
18+
19+
assert len(references) == 1
20+
assert references[0].uri == sushi_customers_uri
21+
22+
# Check that the reference in the correct range is sushi.customers
23+
path = active_customers_uri.removeprefix("file://")
24+
read_file = open(path, "r").readlines()
25+
# Get the string range in the read file
26+
reference_range = references[0].range
27+
start_line = reference_range.start.line
28+
end_line = reference_range.end.line
29+
start_character = reference_range.start.character
30+
end_character = reference_range.end.character
31+
# Get the string from the file
32+
33+
# If the reference spans multiple lines, handle it accordingly
34+
if start_line == end_line:
35+
# Reference is on a single line
36+
line_content = read_file[start_line]
37+
referenced_text = line_content[start_character:end_character]
38+
else:
39+
# Reference spans multiple lines
40+
referenced_text = read_file[start_line][
41+
start_character:
42+
] # First line from start_character to end
43+
for line_num in range(start_line + 1, end_line): # Middle lines (if any)
44+
referenced_text += read_file[line_num]
45+
referenced_text += read_file[end_line][:end_character] # Last line up to end_character
46+
assert referenced_text == "sushi.customers"

vscode/extension/src/lsp/lsp.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ export class LSPClient implements Disposable {
2727

2828
const sqlmesh = await sqlmesh_lsp_exec()
2929
if (isErr(sqlmesh)) {
30-
traceError(`Failed to get sqlmesh_lsp_exec, ${sqlmesh.error.type}`)
30+
traceError(`Failed to get sqlmesh_lsp_exec, ${JSON.stringify(sqlmesh.error)}`)
3131
return sqlmesh
3232
}
3333
const workspaceFolders = getWorkspaceFolders()

0 commit comments

Comments
 (0)