Skip to content

Commit 04afbe6

Browse files
committed
temp
[ci skip]
1 parent 980f371 commit 04afbe6

4 files changed

Lines changed: 86 additions & 102 deletions

File tree

examples/sushi/models/disabled.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ MODEL (
33
enabled False,
44
);
55

6-
SELECT 1 AS a;
6+
SELECT 1 AS a ;

examples/sushi/models/top_waiters.sql

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,18 @@ MODEL (
1010

1111
WITH test_macros AS (
1212
SELECT
13-
@ADD_ONE(1) AS lit_two,
13+
@ADD_ONE(1) AS lit_two,
1414
@MULTIPLY(revenue, 2.0) AS sql_exp,
1515
@SQL_LITERAL(revenue::TEXT, 'x', 'y', a, "b") AS sql_lit
1616
FROM sushi.waiter_revenue_by_day
1717
)
1818
SELECT
19-
waiter_id::INT AS waiter_id ,
19+
waiter_id::INT AS waiter_id,
2020
revenue::DOUBLE AS revenue
2121
FROM sushi.waiter_revenue_by_day
2222
WHERE
2323
event_date = (
24-
SELECT
24+
SELECT
2525
MAX(event_date)
2626
FROM sushi.waiter_revenue_by_day
2727
)

sqlmesh/lsp/main.py

Lines changed: 67 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,33 @@
44
import asyncio
55
import gc
66
import logging
7-
import re
87
import typing as t
98
import weakref
109
from collections import defaultdict
1110
from contextlib import suppress
12-
from functools import lru_cache
1311
from itertools import cycle
1412
from pathlib import Path
1513

1614
import sqlmesh
1715
from lsprotocol import types
1816
from pygls.server import LanguageServer
1917
from pygls.workspace import TextDocument
20-
from sqlmesh.core.dialect import format_model_expressions, parse
18+
from sqlmesh._version import __version__
2119

2220
logger = logging.getLogger(__name__)
2321

24-
WORKSPACE_DIAGNOSTICS: t.Dict[str, t.Tuple[t.Optional[int], t.List[types.Diagnostic]]] = {}
25-
"""A mapping of document URIs to diagnostics."""
26-
2722
CONTEXTS: t.Dict[str, sqlmesh.Context] = {}
2823
"""A mapping of workspace paths to SQLMesh contexts."""
2924

3025
PATHS_TO_MODELS: t.Dict[str, t.Tuple[sqlmesh.Context, sqlmesh.Model]] = {}
3126
"""A mapping of file paths to SQLMesh (context, model) tuples."""
3227

33-
C_MUTEX = defaultdict(asyncio.Lock)
28+
C_MUTEX: t.DefaultDict[t.Union[str, Path], asyncio.Lock] = defaultdict(asyncio.Lock)
3429
"""A locking mechanism for ensuring that context mutation is thread-safe."""
3530

3631
loop = asyncio.get_event_loop()
3732

38-
server = LanguageServer("sqlmesh-lsp", "v0.1.0", loop=loop)
33+
server = LanguageServer("sqlmesh_lsp", __version__, loop=loop)
3934

4035

4136
async def refresh_context_loop(context: sqlmesh.Context) -> None:
@@ -46,11 +41,15 @@ async def refresh_context_loop(context: sqlmesh.Context) -> None:
4641
gc_iter = cycle(list(range(10)))
4742
while True:
4843
await asyncio.sleep(10.0)
49-
if context._loader.reload_needed():
50-
async with C_MUTEX[context.path]:
51-
await asyncio.to_thread(context.load)
44+
for loader in context._loaders:
45+
if loader.reload_needed():
46+
async with C_MUTEX[context.path]:
47+
await asyncio.to_thread(context.load)
5248
PATHS_TO_MODELS.update(
53-
{str(model._path.resolve()): (context, weakref.proxy(model)) for model in context.models.values()}
49+
{
50+
str(model._path.resolve()): (context, weakref.proxy(model))
51+
for model in context.models.values()
52+
}
5453
)
5554
if next(gc_iter) == 0:
5655
gc.collect()
@@ -84,7 +83,10 @@ async def ensure_context_for_document(document: TextDocument) -> TextDocument:
8483
loop.create_task(refresh_context_loop(handle))
8584
CONTEXTS[str(path)] = handle
8685
PATHS_TO_MODELS.update(
87-
{str(model._path.resolve()): (handle, weakref.proxy(model)) for model in handle.models.values()}
86+
{
87+
str(model._path.resolve()): (handle, weakref.proxy(model))
88+
for model in handle.models.values()
89+
}
8890
)
8991
server.show_message(f"Context loaded for: {path}")
9092
loaded = True
@@ -94,99 +96,81 @@ async def ensure_context_for_document(document: TextDocument) -> TextDocument:
9496

9597

9698
@server.feature(types.TEXT_DOCUMENT_FORMATTING)
97-
async def formatting(ls: LanguageServer, params: types.DocumentFormattingParams):
99+
async def formatting(
100+
ls: LanguageServer, params: types.DocumentFormattingParams
101+
) -> t.List[types.TextEdit]:
98102
"""Format the document based using SQLMesh format_model_expressions."""
99-
document = await ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
100-
context, model = PATHS_TO_MODELS.get(document.path, (None, None))
101-
if context is None or model is None:
102-
return []
103-
default_dialect = context.default_dialect
104-
dialect = model.dialect if model and model.is_sql else default_dialect
105-
try:
106-
expressions = parse(document.source, default_dialect=dialect)
107-
except Exception:
108-
return []
109103
try:
110-
fmt_doc = format_model_expressions(expressions, dialect, **context.config.format.generator_options)
111-
if context.config.format.append_newline:
112-
fmt_doc += "\n"
104+
logger.info(f"Formatting document: {params.text_document.uri}")
105+
document = await ensure_context_for_document(
106+
ls.workspace.get_document(params.text_document.uri)
107+
)
108+
context, model = PATHS_TO_MODELS.get(document.path, (None, None))
109+
context.format(paths=[Path(document.path)])
110+
with open(document.path, "r+", encoding="utf-8") as file:
111+
return [
112+
types.TextEdit(
113+
range=types.Range(
114+
types.Position(0, 0),
115+
types.Position(len(document.lines), len(document.lines[-1])),
116+
),
117+
new_text=file.read(),
118+
)
119+
]
113120
except Exception as e:
114121
ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error)
115122
return []
116-
return [
117-
types.TextEdit(
118-
range=types.Range(
119-
types.Position(0, 0),
120-
types.Position(len(document.lines), len(document.lines[-1])),
121-
),
122-
new_text=fmt_doc,
123-
)
124-
]
125-
126-
127-
_top_of_file = types.Range(start=types.Position(line=0, character=0), end=types.Position(line=0, character=0))
128-
129-
_cached_re_compile = t.cast(t.Callable[[str, re.RegexFlag], re.Pattern[str]], lru_cache(maxsize=1024)(re.compile))
130-
131-
132-
def _iter_match_ranges_in_projection(term: str, source: str):
133-
"""Iterate over ranges of matches for a term in a SQL projection."""
134-
col_patt = _cached_re_compile(rf'\b["`]?({term})["`]?,?', re.IGNORECASE)
135-
projection_patt = _cached_re_compile(r"SELECT\s+(.*)\s+FROM", re.DOTALL | re.IGNORECASE)
136-
for p_match in projection_patt.finditer(source):
137-
if not p_match.group(1):
138-
continue
139-
proj_start, proj_end = p_match.span(1)
140-
proj_substr = source[proj_start:proj_end]
141-
for c_match in col_patt.finditer(proj_substr):
142-
if not c_match.group(1):
143-
continue
144-
col_start, col_end = c_match.span(1)
145-
start, end = proj_start + col_start, proj_start + col_end
146-
line = source.count("\n", 0, start)
147-
char_s = start - source.rfind("\n", 0, start) - 1
148-
char_e = end - source.rfind("\n", 0, end)
149-
yield types.Range(
150-
start=types.Position(line=line, character=char_s),
151-
end=types.Position(line=line, character=char_e),
152-
)
153123

154124

155125
@server.feature(types.TEXT_DOCUMENT_DID_OPEN)
156-
async def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams):
126+
async def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
157127
"""Update diagnostics on document open and refresh context if necessary."""
158-
document = await ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
128+
document = await ensure_context_for_document(
129+
ls.workspace.get_document(params.text_document.uri)
130+
)
159131
path = Path(document.path)
160132
known_paths = PATHS_TO_MODELS.keys()
161133
for context in CONTEXTS.values():
162-
if path.is_relative_to(context.path) and path.suffix in (".sql", ".py") and str(path) not in known_paths:
134+
if (
135+
path.is_relative_to(context.path)
136+
and path.suffix in (".sql", ".py")
137+
and str(path) not in known_paths
138+
):
163139
ls.show_message(f"Refreshing context with new file: {path}", types.MessageType.Info)
164140
async with C_MUTEX[context.path]:
165141
await asyncio.to_thread(context.load)
166142
PATHS_TO_MODELS.update(
167-
{str(model._path.resolve()): (context, weakref.proxy(model)) for model in context.models.values()}
143+
{
144+
str(model._path.resolve()): (context, weakref.proxy(model))
145+
for model in context.models.values()
146+
}
168147
)
169148

170149

171150
@server.feature(types.TEXT_DOCUMENT_DID_SAVE)
172-
async def did_save(ls: LanguageServer, params: types.DidOpenTextDocumentParams):
151+
async def did_save(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
173152
"""Update diagnostics on document save."""
174-
document = await ensure_context_for_document(ls.workspace.get_document(params.text_document.uri))
153+
document = await ensure_context_for_document(
154+
ls.workspace.get_document(params.text_document.uri)
155+
)
175156
context, _ = PATHS_TO_MODELS.get(document.path, (None, None))
176157
if context is not None:
177-
context._loader._path_mtimes[Path(document.path)] = 0.0
178-
async with C_MUTEX[context.path]:
179-
await asyncio.to_thread(context.load)
180-
for model in context.models.values():
181-
if model._path == Path(document.path):
182-
PATHS_TO_MODELS[document.path] = (context, weakref.proxy(model))
183-
break
158+
for loader in context._loaders:
159+
loader._path_mtimes[Path(document.path)] = 0.0
160+
async with C_MUTEX[context.path]:
161+
await asyncio.to_thread(context.load)
162+
for model in context.models.values():
163+
if model._path == Path(document.path):
164+
PATHS_TO_MODELS[document.path] = (context, weakref.proxy(model))
165+
break
184166

185167

186168
@server.feature(types.WORKSPACE_DID_CHANGE_WATCHED_FILES)
187-
async def did_change_watched_files(ls: LanguageServer, params: types.DidChangeWatchedFilesParams):
169+
async def did_change_watched_files(
170+
ls: LanguageServer, params: types.DidChangeWatchedFilesParams
171+
) -> None:
188172
"""Refresh context if a file changes."""
189-
updated = {}
173+
updated: t.Dict[t.Union[str, Path], bool] = {}
190174
for change in params.changes:
191175
document = await ensure_context_for_document(ls.workspace.get_text_document(change.uri))
192176
path = Path(document.path)
@@ -215,8 +199,11 @@ async def did_change_watched_files(ls: LanguageServer, params: types.DidChangeWa
215199
)
216200
updated[context.path] = True
217201

218-
def main():
202+
203+
def main() -> None:
204+
logging.basicConfig(level=logging.DEBUG)
219205
server.start_io()
220206

207+
221208
if __name__ == "__main__":
222209
main()

vscode/extension/src/lsp/lsp.ts

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import path from "path";
2-
import { workspace, ExtensionContext } from "vscode";
2+
import { workspace, ExtensionContext, OutputChannel, window } from "vscode";
33
import { ServerOptions, LanguageClientOptions, LanguageClient, TransportKind } from "vscode-languageclient/node";
44
import { sqlmesh_exec, sqlmesh_lsp_exec } from "../sqlmesh/sqlmesh";
55
import { err, isErr, ok, Result } from "../functional/result";
@@ -16,45 +16,42 @@ export async function activateLsp(context: ExtensionContext): Promise<Result<und
1616
if (workspaceFolders.length !== 1) {
1717
return err("Invalid number of workspace folders")
1818
}
19+
const outputChannel: OutputChannel = window.createOutputChannel('sqlmesh_actual_lsp_implementation');
20+
21+
let folder = workspaceFolders[0]
1922
const workspacePath = workspaceFolders[0].uri.fsPath
2023
let serverOptions: ServerOptions = {
2124
run: {
2225
command: sqlmesh.value.bin,
2326
transport: TransportKind.stdio,
2427
options: {
25-
env: {
26-
...process.env,
27-
},
2828
cwd: workspacePath,
29-
}
29+
},
3030
},
3131
debug: {
3232
command: sqlmesh.value.bin,
3333
transport: TransportKind.stdio,
3434
options: {
35-
env: {
36-
...process.env,
37-
},
3835
cwd: workspacePath,
3936
}
4037
}
4138
}
4239
let clientOptions: LanguageClientOptions = {
4340
documentSelector: [
44-
{
45-
scheme: 'file',
46-
language: 'sql',
47-
pattern: '**/*.sql'
48-
}
41+
{ scheme: 'file', pattern: `**/*.sql` }
4942
],
50-
synchronize: {
51-
fileEvents: workspace.createFileSystemWatcher('**/*.{sql,py}'),
52-
}
43+
workspaceFolder: folder,
44+
diagnosticCollectionName: 'sqlmesh',
45+
outputChannel: outputChannel,
46+
// synchronize: {
47+
// fileEvents: workspace.createFileSystemWatcher('**/*.{sql,py}'),
48+
// }
5349
}
5450

55-
client = new LanguageClient('sqlmesh', 'SQLMesh Language Server', serverOptions, clientOptions)
51+
client = new LanguageClient('sqlmesh-lsp-example', 'SQLMesh Language Server', serverOptions, clientOptions)
5652
console.log('Starting language client')
57-
await client.start()
53+
client.start()
54+
5855
console.log('Language client started')
5956
return ok(undefined)
6057
}

0 commit comments

Comments
 (0)