From 3147b42cbbc406c7956850843f15da81ec5e0b7d Mon Sep 17 00:00:00 2001 From: Ben King <9087625+benfdking@users.noreply.github.com> Date: Tue, 10 Jun 2025 17:19:41 +0100 Subject: [PATCH] feat(lsp): custom methods always return errors --- sqlmesh/lsp/api.py | 17 ++- sqlmesh/lsp/custom.py | 28 ++-- sqlmesh/lsp/main.py | 261 +++++++++++++++++---------------- tooling/vscode/extensions.json | 3 +- 4 files changed, 170 insertions(+), 139 deletions(-) diff --git a/sqlmesh/lsp/api.py b/sqlmesh/lsp/api.py index 4be256c194..a034283759 100644 --- a/sqlmesh/lsp/api.py +++ b/sqlmesh/lsp/api.py @@ -9,13 +9,16 @@ import typing as t from pydantic import field_validator -from sqlmesh.utils.pydantic import PydanticModel +from sqlmesh.lsp.custom import ( + CustomMethodRequestBaseClass, + CustomMethodResponseBaseClass, +) from web.server.models import LineageColumn, Model API_FEATURE = "sqlmesh/api" -class ApiRequest(PydanticModel): +class ApiRequest(CustomMethodRequestBaseClass): """ Request to call the SQLMesh API. This is a generic request that can be used to call any API endpoint. @@ -28,7 +31,11 @@ class ApiRequest(PydanticModel): body: t.Optional[t.Dict[str, t.Any]] = None -class ApiResponseGetModels(PydanticModel): +class BaseAPIResponse(CustomMethodResponseBaseClass): + error: t.Optional[str] = None + + +class ApiResponseGetModels(BaseAPIResponse): """ Response from the SQLMesh API for the get_models endpoint. """ @@ -53,7 +60,7 @@ def sanitize_datetime_fields(cls, data: t.List[Model]) -> t.List[Model]: return data -class ApiResponseGetLineage(PydanticModel): +class ApiResponseGetLineage(BaseAPIResponse): """ Response from the SQLMesh API for the get_lineage endpoint. """ @@ -61,7 +68,7 @@ class ApiResponseGetLineage(PydanticModel): data: t.Dict[str, t.List[str]] -class ApiResponseGetColumnLineage(PydanticModel): +class ApiResponseGetColumnLineage(BaseAPIResponse): """ Response from the SQLMesh API for the get_column_lineage endpoint. """ diff --git a/sqlmesh/lsp/custom.py b/sqlmesh/lsp/custom.py index 72c1ec7917..c432802950 100644 --- a/sqlmesh/lsp/custom.py +++ b/sqlmesh/lsp/custom.py @@ -2,10 +2,20 @@ import typing as t from sqlmesh.utils.pydantic import PydanticModel + +class CustomMethodRequestBaseClass(PydanticModel): + pass + + +class CustomMethodResponseBaseClass(PydanticModel): + # Prefixing, so guaranteed not to collide + response_error: t.Optional[str] = None + + ALL_MODELS_FEATURE = "sqlmesh/all_models" -class AllModelsRequest(PydanticModel): +class AllModelsRequest(CustomMethodRequestBaseClass): """ Request to get all the models that are in the current project. """ @@ -13,7 +23,7 @@ class AllModelsRequest(PydanticModel): textDocument: types.TextDocumentIdentifier -class AllModelsResponse(PydanticModel): +class AllModelsResponse(CustomMethodResponseBaseClass): """ Response to get all the models that are in the current project. """ @@ -26,7 +36,7 @@ class AllModelsResponse(PydanticModel): RENDER_MODEL_FEATURE = "sqlmesh/render_model" -class RenderModelRequest(PydanticModel): +class RenderModelRequest(CustomMethodRequestBaseClass): textDocumentUri: str @@ -41,7 +51,7 @@ class RenderModelEntry(PydanticModel): rendered_query: str -class RenderModelResponse(PydanticModel): +class RenderModelResponse(CustomMethodResponseBaseClass): """ Response to render a model. """ @@ -63,11 +73,11 @@ class ModelForRendering(PydanticModel): uri: str -class AllModelsForRenderRequest(PydanticModel): +class AllModelsForRenderRequest(CustomMethodRequestBaseClass): pass -class AllModelsForRenderResponse(PydanticModel): +class AllModelsForRenderResponse(CustomMethodResponseBaseClass): """ Response to get all the models that are in the current project for rendering purposes. """ @@ -94,7 +104,7 @@ class CustomMethod(PydanticModel): name: str -class SupportedMethodsResponse(PydanticModel): +class SupportedMethodsResponse(CustomMethodResponseBaseClass): """ Response containing all supported custom LSP methods. """ @@ -105,7 +115,7 @@ class SupportedMethodsResponse(PydanticModel): FORMAT_PROJECT_FEATURE = "sqlmesh/format_project" -class FormatProjectRequest(PydanticModel): +class FormatProjectRequest(CustomMethodRequestBaseClass): """ Request to format all models in the current project. """ @@ -113,7 +123,7 @@ class FormatProjectRequest(PydanticModel): pass -class FormatProjectResponse(PydanticModel): +class FormatProjectResponse(CustomMethodResponseBaseClass): """ Response to format project request. """ diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py index fdfb105a63..3247bd5b50 100755 --- a/sqlmesh/lsp/main.py +++ b/sqlmesh/lsp/main.py @@ -35,6 +35,7 @@ AllModelsResponse, AllModelsForRenderRequest, AllModelsForRenderResponse, + CustomMethodResponseBaseClass, RenderModelRequest, RenderModelResponse, SupportedMethodsRequest, @@ -54,15 +55,6 @@ from web.server.api.endpoints.lineage import column_lineage, model_lineage from web.server.api.endpoints.models import get_models -SUPPORTED_CUSTOM_METHODS = [ - ALL_MODELS_FEATURE, - RENDER_MODEL_FEATURE, - ALL_MODELS_FOR_RENDER_FEATURE, - API_FEATURE, - SUPPORTED_METHODS_FEATURE, - FORMAT_PROJECT_FEATURE, -] - class SQLMeshLanguageServer: def __init__( @@ -82,6 +74,22 @@ def __init__( self.workspace_folders: t.List[Path] = [] self.client_supports_pull_diagnostics = False + self._supported_custom_methods: t.Dict[ + str, + t.Callable[ + # mypy unable to recognise the base class + [LanguageServer, t.Any], + t.Any, + ], + ] = { + ALL_MODELS_FEATURE: self._custom_all_models, + RENDER_MODEL_FEATURE: self._custom_render_model, + ALL_MODELS_FOR_RENDER_FEATURE: self._custom_all_models_for_render, + API_FEATURE: self._custom_api, + SUPPORTED_METHODS_FEATURE: self._custom_supported_methods, + FORMAT_PROJECT_FEATURE: self._custom_format_project, + } + # Register LSP features (e.g., formatting, hover, etc.) self._register_features() @@ -105,8 +113,128 @@ def _create_lsp_context(self, paths: t.List[Path]) -> t.Optional[LSPContext]: self.server.log_trace(f"Error creating context: {e}") return None + # All the custom LSP methods are registered here and prefixed with _custom + def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse: + uri = URI(params.textDocument.uri) + # Get the document content + content = None + try: + document = ls.workspace.get_text_document(params.textDocument.uri) + content = document.source + except Exception: + pass + try: + context = self._context_get_or_load(uri) + return LSPContext.get_completions(context, uri, content) + except Exception as e: + from sqlmesh.lsp.completions import get_sql_completions + + return get_sql_completions(None, URI(params.textDocument.uri), content) + + def _custom_render_model( + self, ls: LanguageServer, params: RenderModelRequest + ) -> RenderModelResponse: + uri = URI(params.textDocumentUri) + context = self._context_get_or_load(uri) + return RenderModelResponse(models=context.render_model(uri)) + + def _custom_all_models_for_render( + self, ls: LanguageServer, params: AllModelsForRenderRequest + ) -> AllModelsForRenderResponse: + if self.lsp_context is None: + current_path = Path.cwd() + self._ensure_context_in_folder(current_path) + if self.lsp_context is None: + raise RuntimeError("No context found") + return AllModelsForRenderResponse(models=self.lsp_context.list_of_models_for_rendering()) + + def _custom_format_project( + self, ls: LanguageServer, params: FormatProjectRequest + ) -> FormatProjectResponse: + """Format all models in the current project.""" + try: + if self.lsp_context is None: + current_path = Path.cwd() + self._ensure_context_in_folder(current_path) + if self.lsp_context is None: + raise RuntimeError("No context found") + + # Call the format method on the context + self.lsp_context.context.format() + return FormatProjectResponse() + except Exception as e: + ls.log_trace(f"Error formatting project: {e}") + return FormatProjectResponse() + + def _custom_api( + self, ls: LanguageServer, request: ApiRequest + ) -> t.Union[ApiResponseGetModels, ApiResponseGetColumnLineage, ApiResponseGetLineage]: + ls.log_trace(f"API request: {request}") + if self.lsp_context is None: + current_path = Path.cwd() + self._ensure_context_in_folder(current_path) + if self.lsp_context is None: + ls.log_trace("No context found in call") + raise RuntimeError("No context found") + + parsed_url = urllib.parse.urlparse(request.url) + path_parts = parsed_url.path.strip("/").split("/") + + if request.method == "GET": + if path_parts == ["api", "models"]: + # /api/models + return ApiResponseGetModels(data=get_models(self.lsp_context.context)) + + if path_parts[:2] == ["api", "lineage"]: + if len(path_parts) == 3: + # /api/lineage/{model} + model_name = urllib.parse.unquote(path_parts[2]) + lineage = model_lineage(model_name, self.lsp_context.context) + non_set_lineage = {k: v for k, v in lineage.items() if v is not None} + return ApiResponseGetLineage(data=non_set_lineage) + + if len(path_parts) == 4: + # /api/lineage/{model}/{column} + model_name = urllib.parse.unquote(path_parts[2]) + column = urllib.parse.unquote(path_parts[3]) + models_only = False + if hasattr(request, "params"): + models_only = bool(getattr(request.params, "models_only", False)) + column_lineage_response = column_lineage( + model_name, column, models_only, self.lsp_context.context + ) + return ApiResponseGetColumnLineage(data=column_lineage_response) + + raise NotImplementedError(f"API request not implemented: {request.url}") + + def _custom_supported_methods( + self, ls: LanguageServer, params: SupportedMethodsRequest + ) -> SupportedMethodsResponse: + """Return all supported custom LSP methods.""" + return SupportedMethodsResponse( + methods=[ + CustomMethod( + name=name, + ) + for name in self._supported_custom_methods + ] + ) + def _register_features(self) -> None: """Register LSP features on the internal LanguageServer instance.""" + for name, method in self._supported_custom_methods.items(): + + def create_function_call(method_func: t.Callable) -> t.Callable: + def function_call(ls: LanguageServer, params: t.Any) -> t.Dict[str, t.Any]: + try: + response = method_func(ls, params) + except Exception as e: + response = CustomMethodResponseBaseClass(response_error=str(e)) + return response.model_dump(mode="json") + + return function_call + + self.server.feature(name)(create_function_call(method)) @self.server.feature(types.INITIALIZE) def initialize(ls: LanguageServer, params: types.InitializeParams) -> None: @@ -144,121 +272,6 @@ def initialize(ls: LanguageServer, params: types.InitializeParams) -> None: f"Error initializing SQLMesh context: {e}", ) - @self.server.feature(ALL_MODELS_FEATURE) - def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse: - uri = URI(params.textDocument.uri) - - # Get the document content - content = None - try: - document = ls.workspace.get_text_document(params.textDocument.uri) - content = document.source - except Exception: - pass - - try: - context = self._context_get_or_load(uri) - return LSPContext.get_completions(context, uri, content) - except Exception as e: - from sqlmesh.lsp.completions import get_sql_completions - - return get_sql_completions(None, URI(params.textDocument.uri), content) - - @self.server.feature(RENDER_MODEL_FEATURE) - def render_model(ls: LanguageServer, params: RenderModelRequest) -> RenderModelResponse: - uri = URI(params.textDocumentUri) - context = self._context_get_or_load(uri) - return RenderModelResponse(models=context.render_model(uri)) - - @self.server.feature(ALL_MODELS_FOR_RENDER_FEATURE) - def all_models_for_render( - ls: LanguageServer, params: AllModelsForRenderRequest - ) -> AllModelsForRenderResponse: - if self.lsp_context is None: - current_path = Path.cwd() - self._ensure_context_in_folder(current_path) - if self.lsp_context is None: - raise RuntimeError("No context found") - return AllModelsForRenderResponse( - models=self.lsp_context.list_of_models_for_rendering() - ) - - @self.server.feature(SUPPORTED_METHODS_FEATURE) - def supported_methods( - ls: LanguageServer, params: SupportedMethodsRequest - ) -> SupportedMethodsResponse: - """Return all supported custom LSP methods.""" - return SupportedMethodsResponse( - methods=[ - CustomMethod( - name=name, - ) - for name in SUPPORTED_CUSTOM_METHODS - ] - ) - - @self.server.feature(FORMAT_PROJECT_FEATURE) - def format_project( - ls: LanguageServer, params: FormatProjectRequest - ) -> FormatProjectResponse: - """Format all models in the current project.""" - try: - if self.lsp_context is None: - current_path = Path.cwd() - self._ensure_context_in_folder(current_path) - if self.lsp_context is None: - raise RuntimeError("No context found") - - # Call the format method on the context - self.lsp_context.context.format() - return FormatProjectResponse() - except Exception as e: - ls.log_trace(f"Error formatting project: {e}") - return FormatProjectResponse() - - @self.server.feature(API_FEATURE) - def api(ls: LanguageServer, request: ApiRequest) -> t.Dict[str, t.Any]: - ls.log_trace(f"API request: {request}") - if self.lsp_context is None: - current_path = Path.cwd() - self._ensure_context_in_folder(current_path) - if self.lsp_context is None: - raise RuntimeError("No context found") - - parsed_url = urllib.parse.urlparse(request.url) - path_parts = parsed_url.path.strip("/").split("/") - - if request.method == "GET": - if path_parts == ["api", "models"]: - # /api/models - return ApiResponseGetModels( - data=get_models(self.lsp_context.context) - ).model_dump(mode="json") - - if path_parts[:2] == ["api", "lineage"]: - if len(path_parts) == 3: - # /api/lineage/{model} - model_name = urllib.parse.unquote(path_parts[2]) - lineage = model_lineage(model_name, self.lsp_context.context) - non_set_lineage = {k: v for k, v in lineage.items() if v is not None} - return ApiResponseGetLineage(data=non_set_lineage).model_dump(mode="json") - - if len(path_parts) == 4: - # /api/lineage/{model}/{column} - model_name = urllib.parse.unquote(path_parts[2]) - column = urllib.parse.unquote(path_parts[3]) - models_only = False - if hasattr(request, "params"): - models_only = bool(getattr(request.params, "models_only", False)) - column_lineage_response = column_lineage( - model_name, column, models_only, self.lsp_context.context - ) - return ApiResponseGetColumnLineage(data=column_lineage_response).model_dump( - mode="json" - ) - - raise NotImplementedError(f"API request not implemented: {request.url}") - @self.server.feature(types.TEXT_DOCUMENT_DID_OPEN) def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None: uri = URI(params.text_document.uri) diff --git a/tooling/vscode/extensions.json b/tooling/vscode/extensions.json index 9e5e0b733f..0271570408 100644 --- a/tooling/vscode/extensions.json +++ b/tooling/vscode/extensions.json @@ -4,6 +4,7 @@ "recommendations": [ "dbaeumer.vscode-eslint", "amodio.tsl-problem-matcher", - "ms-vscode.extension-test-runner" + "ms-vscode.extension-test-runner", + "ms-playwright.playwright" ] }