Skip to content

Commit 5ba4c88

Browse files
committed
temp
[ci skip]
1 parent 980f371 commit 5ba4c88

4 files changed

Lines changed: 142 additions & 25 deletions

File tree

examples/sushi/models/disabled.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ MODEL (
33
enabled False,
44
);
55

6-
SELECT 1 AS a;
6+
SELECT 1 AS a ;

examples/sushi/models/top_waiters.sql

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,18 @@ MODEL (
1010

1111
WITH test_macros AS (
1212
SELECT
13-
@ADD_ONE(1) AS lit_two,
13+
@ADD_ONE(1) AS lit_two,
1414
@MULTIPLY(revenue, 2.0) AS sql_exp,
1515
@SQL_LITERAL(revenue::TEXT, 'x', 'y', a, "b") AS sql_lit
1616
FROM sushi.waiter_revenue_by_day
1717
)
1818
SELECT
19-
waiter_id::INT AS waiter_id ,
19+
waiter_id::INT AS waiter_id,
2020
revenue::DOUBLE AS revenue
2121
FROM sushi.waiter_revenue_by_day
2222
WHERE
2323
event_date = (
24-
SELECT
24+
SELECT
2525
MAX(event_date)
2626
FROM sushi.waiter_revenue_by_day
2727
)

sqlmesh/lsp/main.py

Lines changed: 123 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#!/usr/bin/env python
22
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration."""
3+
#!/usr/bin/env python
4+
"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration."""
35

46
import asyncio
57
import gc
8+
import io
69
import logging
710
import re
811
import typing as t
@@ -17,7 +20,11 @@
1720
from lsprotocol import types
1821
from pygls.server import LanguageServer
1922
from pygls.workspace import TextDocument
23+
from sqlglot.errors import ParseError
2024
from sqlmesh.core.dialect import format_model_expressions, parse
25+
from sqlmesh.core.lineage import column_description
26+
from sqlmesh.core.model import SqlModel
27+
from sqlmesh.utils import type_is_known
2128

2229
logger = logging.getLogger(__name__)
2330

@@ -34,7 +41,6 @@
3441
"""A locking mechanism for ensuring that context mutation is thread-safe."""
3542

3643
loop = asyncio.get_event_loop()
37-
3844
server = LanguageServer("sqlmesh-lsp", "v0.1.0", loop=loop)
3945

4046

@@ -93,9 +99,34 @@ async def ensure_context_for_document(document: TextDocument) -> TextDocument:
9399
return document
94100

95101

102+
@server.feature(types.TEXT_DOCUMENT_COMPLETION)
103+
async def completions(ls: LanguageServer, params: types.CompletionParams):
104+
"""Provide completions based on upstream model column information."""
105+
items = []
106+
document = await ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
107+
context, model = PATHS_TO_MODELS.get(document.path, (None, None))
108+
if context is None or model is None:
109+
return types.CompletionList(is_incomplete=False, items=[])
110+
for dep in model.depends_on:
111+
model_dep = context.models[dep]
112+
if model_dep.columns_to_types:
113+
for column, type_ in model_dep.columns_to_types.items():
114+
items.append(
115+
types.CompletionItem(
116+
label=column,
117+
label_details=types.CompletionItemLabelDetails(detail=type_.sql()),
118+
documentation=f"Source: {dep}\n\n"
119+
+ (column_description(context, dep, column) or "No description available"),
120+
kind=types.CompletionItemKind.Field,
121+
)
122+
)
123+
return types.CompletionList(is_incomplete=False, items=items)
124+
125+
96126
@server.feature(types.TEXT_DOCUMENT_FORMATTING)
97-
async def formatting(ls: LanguageServer, params: types.DocumentFormattingParams):
127+
async def formatting(ls: LanguageServer, params: types.DocumentFormattingParams) -> t.List[types.TextEdit]:
98128
"""Format the document based using SQLMesh format_model_expressions."""
129+
logger.info(f"Formatting document: {params.text_document.uri}")
99130
document = await ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
100131
context, model = PATHS_TO_MODELS.get(document.path, (None, None))
101132
if context is None or model is None:
@@ -104,7 +135,8 @@ async def formatting(ls: LanguageServer, params: types.DocumentFormattingParams)
104135
dialect = model.dialect if model and model.is_sql else default_dialect
105136
try:
106137
expressions = parse(document.source, default_dialect=dialect)
107-
except Exception:
138+
except Exception as e:
139+
logger.error(f"Exception occurred while parsing document: {e}")
108140
return []
109141
try:
110142
fmt_doc = format_model_expressions(expressions, dialect, **context.config.format.generator_options)
@@ -152,6 +184,66 @@ def _iter_match_ranges_in_projection(term: str, source: str):
152184
)
153185

154186

187+
def _update_diagnostics(document: TextDocument) -> None:
188+
"""Update diagnostics for the given document."""
189+
WORKSPACE_DIAGNOSTICS[document.uri] = (document.version, diagnostics := [])
190+
context, model = PATHS_TO_MODELS.get(document.path, (None, None))
191+
if context is None or model is None:
192+
return
193+
194+
default_dialect = context.default_dialect
195+
dialect = model.dialect if model and model.is_sql else default_dialect
196+
197+
try:
198+
_ = parse(document.source, default_dialect=dialect)
199+
except ParseError as e:
200+
for error in e.errors:
201+
line = error["line"]
202+
comments_before_line = [
203+
_l for _l in document.lines[:line] if _l.strip().startswith(("/*", "--"))
204+
] # This is just a hack to adjust the line number, not a proper solution but it works
205+
line -= len(comments_before_line)
206+
diagnostics.append(
207+
types.Diagnostic(
208+
message=e.args[0],
209+
severity=types.DiagnosticSeverity.Error,
210+
range=types.Range(
211+
start=types.Position(line=line, character=error["col"]),
212+
end=types.Position(line=line, character=error["col"]),
213+
),
214+
)
215+
)
216+
217+
if model is not None and isinstance(model, SqlModel):
218+
sqlmesh_renderer_logger = logging.getLogger("sqlmesh.core.renderer")
219+
buf = io.StringIO()
220+
interceptor = logging.StreamHandler(stream=buf)
221+
interceptor.setLevel(logging.WARNING)
222+
interceptor.setFormatter(logging.Formatter("%(message)s"))
223+
sqlmesh_renderer_logger.addHandler(interceptor)
224+
_ = model._query_renderer.render(execution_time="now")
225+
sqlmesh_renderer_logger.removeHandler(interceptor)
226+
buf.seek(0)
227+
warnings = buf.read().strip()
228+
if warnings:
229+
diagnostics.append(
230+
types.Diagnostic(message=warnings, severity=types.DiagnosticSeverity.Warning, range=_top_of_file)
231+
)
232+
233+
setattr(model, "_columns_to_types", None) # clear cached columns to types
234+
if model and model.columns_to_types:
235+
for column, type_ in model.columns_to_types.items():
236+
if not type_is_known(type_):
237+
for range_ in _iter_match_ranges_in_projection(column, document.source):
238+
diagnostics.append(
239+
types.Diagnostic(
240+
message=f"Unknown type for column: {column} - add a type hint to final projection",
241+
severity=types.DiagnosticSeverity.Warning,
242+
range=range_,
243+
)
244+
)
245+
246+
155247
@server.feature(types.TEXT_DOCUMENT_DID_OPEN)
156248
async def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams):
157249
"""Update diagnostics on document open and refresh context if necessary."""
@@ -166,6 +258,16 @@ async def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams):
166258
PATHS_TO_MODELS.update(
167259
{str(model._path.resolve()): (context, weakref.proxy(model)) for model in context.models.values()}
168260
)
261+
_update_diagnostics(document)
262+
for uri, (version, diagnostics) in WORKSPACE_DIAGNOSTICS.items():
263+
ls.publish_diagnostics(uri=uri, version=version, diagnostics=diagnostics)
264+
265+
266+
@server.feature(types.TEXT_DOCUMENT_DID_CLOSE)
267+
async def did_close(ls: LanguageServer, params: types.DidCloseTextDocumentParams):
268+
"""Remove diagnostics on document close."""
269+
if params.text_document.uri in WORKSPACE_DIAGNOSTICS:
270+
del WORKSPACE_DIAGNOSTICS[params.text_document.uri]
169271

170272

171273
@server.feature(types.TEXT_DOCUMENT_DID_SAVE)
@@ -181,6 +283,18 @@ async def did_save(ls: LanguageServer, params: types.DidOpenTextDocumentParams):
181283
if model._path == Path(document.path):
182284
PATHS_TO_MODELS[document.path] = (context, weakref.proxy(model))
183285
break
286+
_update_diagnostics(document)
287+
for uri, (version, diagnostics) in WORKSPACE_DIAGNOSTICS.items():
288+
ls.publish_diagnostics(uri=uri, version=version, diagnostics=diagnostics)
289+
290+
291+
@server.feature(types.TEXT_DOCUMENT_DID_CHANGE)
292+
async def did_change(ls: LanguageServer, params: types.DidOpenTextDocumentParams):
293+
"""Update diagnostics on document change."""
294+
document = await ensure_context_for_document(ls.workspace.get_text_document(params.text_document.uri))
295+
_update_diagnostics(document)
296+
for uri, (version, diagnostics) in WORKSPACE_DIAGNOSTICS.items():
297+
ls.publish_diagnostics(uri=uri, version=version, diagnostics=diagnostics)
184298

185299

186300
@server.feature(types.WORKSPACE_DID_CHANGE_WATCHED_FILES)
@@ -189,6 +303,9 @@ async def did_change_watched_files(ls: LanguageServer, params: types.DidChangeWa
189303
updated = {}
190304
for change in params.changes:
191305
document = await ensure_context_for_document(ls.workspace.get_text_document(change.uri))
306+
if change.type == types.FileChangeType.Changed:
307+
_update_diagnostics(document)
308+
continue
192309
path = Path(document.path)
193310
known_paths = PATHS_TO_MODELS.keys()
194311
if change.type == types.FileChangeType.Deleted and str(path) in known_paths:
@@ -215,8 +332,11 @@ async def did_change_watched_files(ls: LanguageServer, params: types.DidChangeWa
215332
)
216333
updated[context.path] = True
217334

335+
218336
def main():
337+
logging.basicConfig(level=logging.DEBUG)
219338
server.start_io()
220339

340+
221341
if __name__ == "__main__":
222342
main()

vscode/extension/src/lsp/lsp.ts

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import path from "path";
2-
import { workspace, ExtensionContext } from "vscode";
2+
import { workspace, ExtensionContext, OutputChannel, window } from "vscode";
33
import { ServerOptions, LanguageClientOptions, LanguageClient, TransportKind } from "vscode-languageclient/node";
44
import { sqlmesh_exec, sqlmesh_lsp_exec } from "../sqlmesh/sqlmesh";
55
import { err, isErr, ok, Result } from "../functional/result";
@@ -16,45 +16,42 @@ export async function activateLsp(context: ExtensionContext): Promise<Result<und
1616
if (workspaceFolders.length !== 1) {
1717
return err("Invalid number of workspace folders")
1818
}
19+
const outputChannel: OutputChannel = window.createOutputChannel('sqlmesh_actual_lsp_implementation');
20+
21+
let folder = workspaceFolders[0]
1922
const workspacePath = workspaceFolders[0].uri.fsPath
2023
let serverOptions: ServerOptions = {
2124
run: {
2225
command: sqlmesh.value.bin,
2326
transport: TransportKind.stdio,
2427
options: {
25-
env: {
26-
...process.env,
27-
},
2828
cwd: workspacePath,
29-
}
29+
},
3030
},
3131
debug: {
3232
command: sqlmesh.value.bin,
3333
transport: TransportKind.stdio,
3434
options: {
35-
env: {
36-
...process.env,
37-
},
3835
cwd: workspacePath,
3936
}
4037
}
4138
}
4239
let clientOptions: LanguageClientOptions = {
4340
documentSelector: [
44-
{
45-
scheme: 'file',
46-
language: 'sql',
47-
pattern: '**/*.sql'
48-
}
41+
{ scheme: 'file', pattern: `**/*.sql` }
4942
],
50-
synchronize: {
51-
fileEvents: workspace.createFileSystemWatcher('**/*.{sql,py}'),
52-
}
43+
workspaceFolder: folder,
44+
diagnosticCollectionName: 'sqlmesh',
45+
outputChannel: outputChannel,
46+
// synchronize: {
47+
// fileEvents: workspace.createFileSystemWatcher('**/*.{sql,py}'),
48+
// }
5349
}
5450

55-
client = new LanguageClient('sqlmesh', 'SQLMesh Language Server', serverOptions, clientOptions)
51+
client = new LanguageClient('sqlmesh-lsp-example', 'SQLMesh Language Server', serverOptions, clientOptions)
5652
console.log('Starting language client')
57-
await client.start()
53+
client.start()
54+
5855
console.log('Language client started')
5956
return ok(undefined)
6057
}

0 commit comments

Comments
 (0)