Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions sqlmesh/lsp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
get_references,
get_all_references,
)
from sqlmesh.lsp.rename import prepare_rename, rename_symbol, get_document_highlights
from sqlmesh.lsp.uri import URI
from web.server.api.endpoints.lineage import column_lineage, model_lineage
from web.server.api.endpoints.models import get_models
Expand Down Expand Up @@ -435,6 +436,59 @@ def find_references(
ls.show_message(f"Error getting locations: {e}", types.MessageType.Error)
return None

@self.server.feature(types.TEXT_DOCUMENT_PREPARE_RENAME)
def prepare_rename_handler(
ls: LanguageServer, params: types.PrepareRenameParams
) -> t.Optional[types.PrepareRenameResult]:
"""Prepare for rename operation by checking if the symbol can be renamed."""
try:
uri = URI(params.text_document.uri)
self._ensure_context_for_document(uri)
if self.lsp_context is None:
raise RuntimeError(f"No context found for document: {uri}")

result = prepare_rename(self.lsp_context, uri, params.position)
return result
except Exception as e:
ls.log_trace(f"Error preparing rename: {e}")
Comment thread
themisvaltinos marked this conversation as resolved.
return None

@self.server.feature(types.TEXT_DOCUMENT_RENAME)
def rename_handler(
ls: LanguageServer, params: types.RenameParams
) -> t.Optional[types.WorkspaceEdit]:
"""Perform rename operation on the symbol at the given position."""
try:
uri = URI(params.text_document.uri)
self._ensure_context_for_document(uri)
if self.lsp_context is None:
raise RuntimeError(f"No context found for document: {uri}")

workspace_edit = rename_symbol(
self.lsp_context, uri, params.position, params.new_name
)
return workspace_edit
except Exception as e:
ls.show_message(f"Error performing rename: {e}", types.MessageType.Error)
return None

@self.server.feature(types.TEXT_DOCUMENT_DOCUMENT_HIGHLIGHT)
def document_highlight_handler(
ls: LanguageServer, params: types.DocumentHighlightParams
) -> t.Optional[t.List[types.DocumentHighlight]]:
"""Highlight all occurrences of the symbol at the given position."""
try:
uri = URI(params.text_document.uri)
self._ensure_context_for_document(uri)
if self.lsp_context is None:
raise RuntimeError(f"No context found for document: {uri}")

highlights = get_document_highlights(self.lsp_context, uri, params.position)
return highlights
except Exception as e:
ls.log_trace(f"Error getting document highlights: {e}")
return None

@self.server.feature(types.TEXT_DOCUMENT_DIAGNOSTIC)
def diagnostic(
ls: LanguageServer, params: types.DocumentDiagnosticParams
Expand Down
137 changes: 137 additions & 0 deletions sqlmesh/lsp/rename.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import typing as t
from lsprotocol.types import (
Position,
TextEdit,
WorkspaceEdit,
PrepareRenameResult_Type1,
DocumentHighlight,
DocumentHighlightKind,
)

from sqlmesh.lsp.context import LSPContext
from sqlmesh.lsp.reference import (
_position_within_range,
get_cte_references,
LSPCteReference,
)
from sqlmesh.lsp.uri import URI


def prepare_rename(
lsp_context: LSPContext, document_uri: URI, position: Position
) -> t.Optional[PrepareRenameResult_Type1]:
"""
Prepare for rename operation by checking if the symbol at the position can be renamed.

Args:
lsp_context: The LSP context
document_uri: The URI of the document
position: The position in the document

Returns:
PrepareRenameResult if the symbol can be renamed, None otherwise
"""
# Check if there's a CTE at this position
cte_references = get_cte_references(lsp_context, document_uri, position)
if cte_references:
# Find the target CTE definition to get its range
target_range = None
for ref in cte_references:
# Check if cursor is on a CTE usage
if _position_within_range(position, ref.range):
target_range = ref.target_range
break
# Check if cursor is on the CTE definition
elif _position_within_range(position, ref.target_range):
target_range = ref.target_range
break
if target_range:
return PrepareRenameResult_Type1(range=target_range, placeholder="cte_name")

# For now, only CTEs are supported
return None


def rename_symbol(
lsp_context: LSPContext, document_uri: URI, position: Position, new_name: str
) -> t.Optional[WorkspaceEdit]:
"""
Perform rename operation on the symbol at the given position.

Args:
lsp_context: The LSP context
document_uri: The URI of the document
position: The position in the document
new_name: The new name for the symbol

Returns:
WorkspaceEdit with the changes, or None if no symbol to rename
"""
Comment thread
themisvaltinos marked this conversation as resolved.
# Check if there's a CTE at this position
cte_references = get_cte_references(lsp_context, document_uri, position)
if cte_references:
return _rename_cte(cte_references, new_name)

# For now, only CTEs are supported
return None


def _rename_cte(cte_references: t.List[LSPCteReference], new_name: str) -> WorkspaceEdit:
"""
Create a WorkspaceEdit for renaming a CTE.

Args:
cte_references: List of CTE references (definition and usages)
new_name: The new name for the CTE

Returns:
WorkspaceEdit with the text edits for renaming the CTE
"""
changes: t.Dict[str, t.List[TextEdit]] = {}

for ref in cte_references:
uri = ref.uri
if uri not in changes:
changes[uri] = []

# Create a text edit for this reference
text_edit = TextEdit(range=ref.range, new_text=new_name)
changes[uri].append(text_edit)

return WorkspaceEdit(changes=changes)


def get_document_highlights(
lsp_context: LSPContext, document_uri: URI, position: Position
) -> t.Optional[t.List[DocumentHighlight]]:
"""
Get document highlights for all occurrences of the symbol at the given position.

This function finds all occurrences of a symbol (CTE) within the current document
and returns them as DocumentHighlight objects for "Change All Occurrences" feature.

Args:
lsp_context: The LSP context
document_uri: The URI of the document
position: The position in the document to find highlights for

Returns:
List of DocumentHighlight objects or None if no symbol found
"""
# Check if there's a CTE at this position
cte_references = get_cte_references(lsp_context, document_uri, position)
if cte_references:
highlights = []
for ref in cte_references:
# Determine the highlight kind based on whether it's a definition or usage
kind = (
DocumentHighlightKind.Write
if ref.range == ref.target_range
else DocumentHighlightKind.Read
)

highlights.append(DocumentHighlight(range=ref.range, kind=kind))
return highlights

# For now, only CTEs are supported
return None
111 changes: 111 additions & 0 deletions tests/lsp/test_document_highlight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from lsprotocol.types import Position, DocumentHighlightKind

from sqlmesh.core.context import Context
from sqlmesh.lsp.context import LSPContext, ModelTarget
from sqlmesh.lsp.rename import get_document_highlights
from sqlmesh.lsp.uri import URI
from tests.lsp.test_reference_cte import find_ranges_from_regex


def test_get_document_highlights_cte():
context = Context(paths=["examples/sushi"])
lsp_context = LSPContext(context)

# Use the existing customers.sql model which has CTEs
sushi_customers_path = next(
path
for path, info in lsp_context.map.items()
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
)

with open(sushi_customers_path, "r", encoding="utf-8") as file:
read_file = file.readlines()

test_uri = URI.from_path(sushi_customers_path)

# Find the ranges for "current_marketing" CTE (not outer one)
ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
assert len(ranges) >= 2 # Should have definition + usage

# Test highlighting CTE definition - position on "current_marketing" definition
position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4)
highlights = get_document_highlights(lsp_context, test_uri, position)

assert highlights is not None
assert len(highlights) >= 2 # Definition + at least 1 usage

# Check that we have both definition (Write) and usage (Read) highlights
highlight_kinds = [h.kind for h in highlights]
assert DocumentHighlightKind.Write in highlight_kinds # CTE definition
assert DocumentHighlightKind.Read in highlight_kinds # CTE usage

# Test highlighting CTE usage - position on "current_marketing" usage
position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
highlights = get_document_highlights(lsp_context, test_uri, position)

assert highlights is not None
assert len(highlights) >= 2 # Should find the same references


def test_get_document_highlights_no_symbol():
context = Context(paths=["examples/sushi"])
lsp_context = LSPContext(context)

# Use the existing customers.sql model
sushi_customers_path = next(
path
for path, info in lsp_context.map.items()
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
)

test_uri = URI.from_path(sushi_customers_path)

# Test position not on any CTE symbol - just on a random keyword
position = Position(line=5, character=5)
highlights = get_document_highlights(lsp_context, test_uri, position)

assert highlights is None


def test_get_document_highlights_multiple_ctes():
context = Context(paths=["examples/sushi"])
lsp_context = LSPContext(context)

# Use the existing customers.sql model which has both outer and inner CTEs
sushi_customers_path = next(
path
for path, info in lsp_context.map.items()
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
)

with open(sushi_customers_path, "r", encoding="utf-8") as file:
read_file = file.readlines()

test_uri = URI.from_path(sushi_customers_path)

# Test the outer CTE - "current_marketing_outer"
outer_ranges = find_ranges_from_regex(read_file, r"current_marketing_outer")
assert len(outer_ranges) >= 2 # Should have definition + usage

# Test highlighting outer CTE - should only highlight that CTE
position = Position(
line=outer_ranges[0].start.line, character=outer_ranges[0].start.character + 4
)
highlights = get_document_highlights(lsp_context, test_uri, position)

assert highlights is not None
assert len(highlights) == len(outer_ranges) # Should match all occurrences of outer CTE

# Test the inner CTE - "current_marketing" (not outer)
inner_ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
assert len(inner_ranges) >= 2 # Should have definition + usage

# Test highlighting inner CTE - should only highlight that CTE, not the outer one
position = Position(
line=inner_ranges[0].start.line, character=inner_ranges[0].start.character + 4
)
highlights = get_document_highlights(lsp_context, test_uri, position)

# This should return the column usages as well
assert highlights is not None
assert len(highlights) == 4
Loading