Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions sqlmesh/lsp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,35 @@
)


def to_lsp_range(
range: SQLMeshRange,
) -> Range:
def to_sqlmesh_position(position: Position) -> SQLMeshPosition:
"""
Converts a SQLMesh Range to an LSP Range.
Converts an LSP Position to a SQLMesh Position.
"""
return Range(
start=Position(line=range.start.line, character=range.start.character),
end=Position(line=range.end.line, character=range.end.character),
)
return SQLMeshPosition(line=position.line, character=position.character)


def to_lsp_position(
position: SQLMeshPosition,
) -> Position:
def to_lsp_position(position: SQLMeshPosition) -> Position:
"""
Converts a SQLMesh Position to an LSP Position.
"""
return Position(line=position.line, character=position.character)


def to_sqlmesh_range(range: Range) -> SQLMeshRange:
"""
Converts an LSP Range to a SQLMesh Range.
"""
return SQLMeshRange(
start=to_sqlmesh_position(range.start),
end=to_sqlmesh_position(range.end),
)


def to_lsp_range(range: SQLMeshRange) -> Range:
"""
Converts a SQLMesh Range to an LSP Range.
"""
return Range(
start=to_lsp_position(range.start),
end=to_lsp_position(range.end),
)
26 changes: 21 additions & 5 deletions sqlmesh/lsp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from sqlmesh.lsp.reference import (
LSPCteReference,
LSPModelReference,
LSPExternalModelReference,
get_references,
get_all_references,
)
Expand Down Expand Up @@ -444,11 +445,19 @@ def goto_definition(
references = get_references(self.lsp_context, uri, params.position)
location_links = []
for reference in references:
# Use target_range if available (CTEs, Macros), otherwise default to start of file
if not isinstance(reference, LSPModelReference):
target_range = reference.target_range
target_selection_range = reference.target_range
else:
# Use target_range if available (CTEs, Macros, and external models in YAML)
if isinstance(reference, LSPModelReference):
# Regular SQL models - default to start of file
target_range = types.Range(
start=types.Position(line=0, character=0),
end=types.Position(line=0, character=0),
)
target_selection_range = types.Range(
start=types.Position(line=0, character=0),
end=types.Position(line=0, character=0),
)
elif isinstance(reference, LSPExternalModelReference):
# External models may have target_range set for YAML files
target_range = types.Range(
start=types.Position(line=0, character=0),
end=types.Position(line=0, character=0),
Expand All @@ -457,6 +466,13 @@ def goto_definition(
start=types.Position(line=0, character=0),
end=types.Position(line=0, character=0),
)
if reference.target_range is not None:
target_range = reference.target_range
target_selection_range = reference.target_range
else:
Comment thread
benfdking marked this conversation as resolved.
# CTEs and Macros always have target_range
target_range = reference.target_range
target_selection_range = reference.target_range

location_links.append(
types.LocationLink(
Expand Down
86 changes: 74 additions & 12 deletions sqlmesh/lsp/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlmesh.core.linter.helpers import (
TokenPositionDetails,
)
from sqlmesh.core.model.definition import SqlModel
from sqlmesh.core.model.definition import SqlModel, ExternalModel
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
from sqlglot import exp
from sqlmesh.lsp.description import generate_markdown_description
Expand All @@ -22,17 +22,28 @@
from sqlmesh.core.model import Model
from sqlmesh import macro
import inspect
from ruamel.yaml import YAML


class LSPModelReference(PydanticModel):
"""A LSP reference to a model."""
"""A LSP reference to a model, excluding external models."""

type: t.Literal["model"] = "model"
uri: str
range: Range
markdown_description: t.Optional[str] = None


class LSPExternalModelReference(PydanticModel):
"""A LSP reference to an external model."""

type: t.Literal["external_model"] = "external_model"
uri: str
range: Range
markdown_description: t.Optional[str] = None
target_range: t.Optional[Range] = None


class LSPCteReference(PydanticModel):
"""A LSP reference to a CTE."""

Expand All @@ -53,7 +64,8 @@ class LSPMacroReference(PydanticModel):


Reference = t.Annotated[
t.Union[LSPModelReference, LSPCteReference, LSPMacroReference], Field(discriminator="type")
t.Union[LSPModelReference, LSPCteReference, LSPMacroReference, LSPExternalModelReference],
Field(discriminator="type"),
]


Expand Down Expand Up @@ -243,16 +255,38 @@ def get_model_definitions_for_a_path(

description = generate_markdown_description(referenced_model)

references.append(
LSPModelReference(
uri=referenced_model_uri.value,
range=Range(
start=to_lsp_position(start_pos_sqlmesh),
end=to_lsp_position(end_pos_sqlmesh),
),
markdown_description=description,
# For external models in YAML files, find the specific model block
if isinstance(referenced_model, ExternalModel):
yaml_target_range: t.Optional[Range] = None
if (
referenced_model_path.suffix in (".yaml", ".yml")
and referenced_model_path.is_file()
):
yaml_target_range = _get_yaml_model_range(
referenced_model_path, referenced_model.name
)
references.append(
LSPExternalModelReference(
uri=referenced_model_uri.value,
range=Range(
start=to_lsp_position(start_pos_sqlmesh),
end=to_lsp_position(end_pos_sqlmesh),
),
markdown_description=description,
target_range=yaml_target_range,
)
)
else:
references.append(
LSPModelReference(
uri=referenced_model_uri.value,
range=Range(
start=to_lsp_position(start_pos_sqlmesh),
end=to_lsp_position(end_pos_sqlmesh),
),
markdown_description=description,
)
)
)

return references

Expand Down Expand Up @@ -699,3 +733,31 @@ def _position_within_range(position: Position, range: Range) -> bool:
range.end.line > position.line
or (range.end.line == position.line and range.end.character >= position.character)
)


def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]:
"""
Find the range of a specific model block in a YAML file.

Args:
yaml_path: Path to the YAML file
model_name: Name of the model to find

Returns:
The Range of the model block in the YAML file, or None if not found
"""
yaml = YAML()
with path.open("r", encoding="utf-8") as f:
data = yaml.load(f)

if not isinstance(data, list):
return None

for item in data:
if isinstance(item, dict) and item.get("name") == model_name:
# Get size of block by taking the earliest line/col in the items block and the last line/col of the block
position_data = item.lc.data["name"] # type: ignore
start = Position(line=position_data[2], character=position_data[3])
end = Position(line=position_data[2], character=position_data[3] + len(item["name"]))
Comment thread
benfdking marked this conversation as resolved.
return Range(start=start, end=end)
return None
38 changes: 38 additions & 0 deletions tests/lsp/test_reference_external_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from lsprotocol.types import Position
from sqlmesh.core.context import Context
from sqlmesh.core.linter.helpers import read_range_from_file
from sqlmesh.lsp.context import LSPContext, ModelTarget
from sqlmesh.lsp.helpers import to_sqlmesh_range
from sqlmesh.lsp.reference import get_references, LSPExternalModelReference
from sqlmesh.lsp.uri import URI


def test_reference() -> None:
context = Context(paths=["examples/sushi"])
lsp_context = LSPContext(context)

# Find model URIs
customers = next(
path
for path, info in lsp_context.map.items()
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
)

# Position of reference in file sushi.customers for sushi.raw_demographics
position = Position(line=42, character=20)
references = get_references(lsp_context, URI.from_path(customers), position)

assert len(references) == 1
reference = references[0]
assert isinstance(reference, LSPExternalModelReference)
assert reference.uri.endswith("external_models.yaml")

source_range = read_range_from_file(customers, to_sqlmesh_range(reference.range))
assert source_range == "raw.demographics"

if reference.target_range is None:
raise AssertionError("Reference target range should not be None")
target_range = read_range_from_file(
URI(reference.uri).to_path(), to_sqlmesh_range(reference.target_range)
)
assert target_range == "raw.demographics"