diff --git a/sqlmesh/lsp/helpers.py b/sqlmesh/lsp/helpers.py index 7aa06ccb4c..920a93f5c7 100644 --- a/sqlmesh/lsp/helpers.py +++ b/sqlmesh/lsp/helpers.py @@ -6,22 +6,35 @@ ) -def to_lsp_range( - range: SQLMeshRange, -) -> Range: +def to_sqlmesh_position(position: Position) -> SQLMeshPosition: """ - Converts a SQLMesh Range to an LSP Range. + Converts an LSP Position to a SQLMesh Position. """ - return Range( - start=Position(line=range.start.line, character=range.start.character), - end=Position(line=range.end.line, character=range.end.character), - ) + return SQLMeshPosition(line=position.line, character=position.character) -def to_lsp_position( - position: SQLMeshPosition, -) -> Position: +def to_lsp_position(position: SQLMeshPosition) -> Position: """ Converts a SQLMesh Position to an LSP Position. """ return Position(line=position.line, character=position.character) + + +def to_sqlmesh_range(range: Range) -> SQLMeshRange: + """ + Converts an LSP Range to a SQLMesh Range. + """ + return SQLMeshRange( + start=to_sqlmesh_position(range.start), + end=to_sqlmesh_position(range.end), + ) + + +def to_lsp_range(range: SQLMeshRange) -> Range: + """ + Converts a SQLMesh Range to an LSP Range. + """ + return Range( + start=to_lsp_position(range.start), + end=to_lsp_position(range.end), + ) diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index 3247bd5b50..edce6c63e0 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -48,6 +48,7 @@ from sqlmesh.lsp.reference import ( LSPCteReference, LSPModelReference, + LSPExternalModelReference, get_references, get_all_references, ) @@ -444,11 +445,19 @@ def goto_definition( references = get_references(self.lsp_context, uri, params.position) location_links = [] for reference in references: - # Use target_range if available (CTEs, Macros), otherwise default to start of file - if not isinstance(reference, LSPModelReference): - target_range = reference.target_range - target_selection_range = reference.target_range - else: + # Use target_range if available (CTEs, Macros, and external models in YAML) + if isinstance(reference, LSPModelReference): + # Regular SQL models - default to start of file + target_range = types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ) + target_selection_range = types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ) + elif isinstance(reference, LSPExternalModelReference): + # External models may have target_range set for YAML files target_range = types.Range( start=types.Position(line=0, character=0), end=types.Position(line=0, character=0), @@ -457,6 +466,13 @@ def goto_definition( start=types.Position(line=0, character=0), end=types.Position(line=0, character=0), ) + if reference.target_range is not None: + target_range = reference.target_range + target_selection_range = reference.target_range + else: + # CTEs and Macros always have target_range + target_range = reference.target_range + target_selection_range = reference.target_range location_links.append( types.LocationLink( diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py index e401be898c..ac4d5374b6 100644 --- a/sqlmesh/lsp/reference.py +++ b/sqlmesh/lsp/reference.py @@ -8,7 +8,7 @@ from sqlmesh.core.linter.helpers import ( TokenPositionDetails, ) -from sqlmesh.core.model.definition import SqlModel +from sqlmesh.core.model.definition import SqlModel, ExternalModel from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget from sqlglot import exp from sqlmesh.lsp.description import generate_markdown_description @@ -22,10 +22,11 @@ from sqlmesh.core.model import Model from sqlmesh import macro import inspect +from ruamel.yaml import YAML class LSPModelReference(PydanticModel): - """A LSP reference to a model.""" + """A LSP reference to a model, excluding external models.""" type: t.Literal["model"] = "model" uri: str @@ -33,6 +34,16 @@ class LSPModelReference(PydanticModel): markdown_description: t.Optional[str] = None +class LSPExternalModelReference(PydanticModel): + """A LSP reference to an external model.""" + + type: t.Literal["external_model"] = "external_model" + uri: str + range: Range + markdown_description: t.Optional[str] = None + target_range: t.Optional[Range] = None + + class LSPCteReference(PydanticModel): """A LSP reference to a CTE.""" @@ -53,7 +64,8 @@ class LSPMacroReference(PydanticModel): Reference = t.Annotated[ - t.Union[LSPModelReference, LSPCteReference, LSPMacroReference], Field(discriminator="type") + t.Union[LSPModelReference, LSPCteReference, LSPMacroReference, LSPExternalModelReference], + Field(discriminator="type"), ] @@ -243,16 +255,38 @@ def get_model_definitions_for_a_path( description = generate_markdown_description(referenced_model) - references.append( - LSPModelReference( - uri=referenced_model_uri.value, - range=Range( - start=to_lsp_position(start_pos_sqlmesh), - end=to_lsp_position(end_pos_sqlmesh), - ), - markdown_description=description, + # For external models in YAML files, find the specific model block + if isinstance(referenced_model, ExternalModel): + yaml_target_range: t.Optional[Range] = None + if ( + referenced_model_path.suffix in (".yaml", ".yml") + and referenced_model_path.is_file() + ): + yaml_target_range = _get_yaml_model_range( + referenced_model_path, referenced_model.name + ) + references.append( + LSPExternalModelReference( + uri=referenced_model_uri.value, + range=Range( + start=to_lsp_position(start_pos_sqlmesh), + end=to_lsp_position(end_pos_sqlmesh), + ), + markdown_description=description, + target_range=yaml_target_range, + ) + ) + else: + references.append( + LSPModelReference( + uri=referenced_model_uri.value, + range=Range( + start=to_lsp_position(start_pos_sqlmesh), + end=to_lsp_position(end_pos_sqlmesh), + ), + markdown_description=description, + ) ) - ) return references @@ -699,3 +733,31 @@ def _position_within_range(position: Position, range: Range) -> bool: range.end.line > position.line or (range.end.line == position.line and range.end.character >= position.character) ) + + +def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]: + """ + Find the range of a specific model block in a YAML file. + + Args: + yaml_path: Path to the YAML file + model_name: Name of the model to find + + Returns: + The Range of the model block in the YAML file, or None if not found + """ + yaml = YAML() + with path.open("r", encoding="utf-8") as f: + data = yaml.load(f) + + if not isinstance(data, list): + return None + + for item in data: + if isinstance(item, dict) and item.get("name") == model_name: + # Get size of block by taking the earliest line/col in the items block and the last line/col of the block + position_data = item.lc.data["name"] # type: ignore + start = Position(line=position_data[2], character=position_data[3]) + end = Position(line=position_data[2], character=position_data[3] + len(item["name"])) + return Range(start=start, end=end) + return None diff --git a/tests/lsp/test_reference_external_model.py b/tests/lsp/test_reference_external_model.py new file mode 100644 index 0000000000..ebf6420934 --- /dev/null +++ b/tests/lsp/test_reference_external_model.py @@ -0,0 +1,38 @@ +from lsprotocol.types import Position +from sqlmesh.core.context import Context +from sqlmesh.core.linter.helpers import read_range_from_file +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.helpers import to_sqlmesh_range +from sqlmesh.lsp.reference import get_references, LSPExternalModelReference +from sqlmesh.lsp.uri import URI + + +def test_reference() -> None: + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Find model URIs + customers = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + # Position of reference in file sushi.customers for sushi.raw_demographics + position = Position(line=42, character=20) + references = get_references(lsp_context, URI.from_path(customers), position) + + assert len(references) == 1 + reference = references[0] + assert isinstance(reference, LSPExternalModelReference) + assert reference.uri.endswith("external_models.yaml") + + source_range = read_range_from_file(customers, to_sqlmesh_range(reference.range)) + assert source_range == "raw.demographics" + + if reference.target_range is None: + raise AssertionError("Reference target range should not be None") + target_range = read_range_from_file( + URI(reference.uri).to_path(), to_sqlmesh_range(reference.target_range) + ) + assert target_range == "raw.demographics"