Skip to content

Commit 169d531

Browse files
committed
feat: add model autocomplete to vscode
[ci skip]
1 parent e6cbce6 commit 169d531

8 files changed

Lines changed: 146 additions & 102 deletions

File tree

sqlmesh/lsp/custom.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from lsprotocol import types
2+
import typing as t
3+
from sqlmesh.utils.pydantic import PydanticModel
4+
5+
ALL_MODELS_FEATURE = 'sqlmesh/all_models'
6+
7+
class AllModelsRequest(PydanticModel):
8+
"""
9+
Request to get all the models that are in the current project.
10+
"""
11+
textDocument: types.TextDocumentIdentifier
12+
13+
class AllModelsResponse(PydanticModel):
14+
"""
15+
Response to get all the models that are in the current project.
16+
"""
17+
models: t.List[str]
18+

sqlmesh/lsp/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sqlmesh.core.context import Context
1313
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
1414
from sqlmesh.lsp.context import LSPContext
15+
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
1516
from sqlmesh.lsp.reference import get_model_definitions_for_a_path
1617

1718

@@ -38,6 +39,12 @@ def __init__(
3839
def _register_features(self) -> None:
3940
"""Register LSP features on the internal LanguageServer instance."""
4041

42+
@self.server.feature(ALL_MODELS_FEATURE)
43+
def all_models(ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse:
44+
context = self._context_get_or_load(params.textDocument.uri)
45+
models = context.context.models
46+
return AllModelsResponse(models=[model.name for model in models])
47+
4148
@self.server.feature(types.TEXT_DOCUMENT_DID_OPEN)
4249
def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None:
4350
context = self._context_get_or_load(params.text_document.uri)

sqlmesh/lsp/reference.py

Lines changed: 41 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@ class Reference(PydanticModel):
1515
range: Range
1616
uri: str
1717

18-
19-
def get_model_definitions_for_a_path(
20-
lint_context: LSPContext, document_uri: str
21-
) -> t.List[Reference]:
18+
def get_model_definitions_for_a_path(lint_context: LSPContext, document_uri: str) -> t.List[Reference]:
2219
"""
2320
Get the model references for a given path.
2421
@@ -36,41 +33,38 @@ def get_model_definitions_for_a_path(
3633
# Ensure the path is a sql model
3734
if not document_uri.endswith(".sql"):
3835
return []
39-
40-
# Get the model
36+
37+
# Get the model
4138
models = lint_context.map[document_uri]
42-
if models is None:
39+
if models is None:
4340
return []
4441
if len(models) == 0:
4542
return []
4643
model_name = models[0]
4744
model = lint_context.context.get_model(model_or_snapshot=model_name, raise_if_missing=False)
48-
if model is None:
45+
if model is None:
4946
return []
5047
if not isinstance(model, SqlModel):
5148
return []
52-
49+
5350
# Find all possible references
5451
tables = list(model.query.find_all(exp.Table))
5552
if len(tables) == 0:
5653
return []
57-
54+
5855
references = []
5956
for table in tables:
6057
depends_on = model.depends_on
6158

6259
# Normalize the table reference
6360
reference_name = table.this.this if table.db is None else f"{table.db}.{table.this.this}"
64-
normalized_reference_name = normalize_model_name(
65-
reference_name, default_catalog=lint_context.context.default_catalog
66-
)
61+
normalized_reference_name = normalize_model_name(reference_name, default_catalog=lint_context.context.default_catalog)
6762
if normalized_reference_name not in depends_on:
6863
continue
6964

65+
7066
# Get the referenced model uri
71-
referenced_model = lint_context.context.get_model(
72-
model_or_snapshot=normalized_reference_name, raise_if_missing=False
73-
)
67+
referenced_model = lint_context.context.get_model(model_or_snapshot=normalized_reference_name, raise_if_missing=False)
7468
if referenced_model is None:
7569
continue
7670
# Get the model uri
@@ -79,96 +73,42 @@ def get_model_definitions_for_a_path(
7973
continue
8074
# Fully qualify the path in case
8175
path = Path.resolve(Path(referenced_model_path))
82-
referenced_model_uri = f"file://{path}"
83-
read_file = open(path, "r").readlines()
84-
76+
referenced_model_path = f"file://{path}"
77+
# Get the path to the file containing the reference
78+
# file_path = document_uri.removeprefix("file://")
79+
# read_file = open(file_path, "r").readlines() # Reading the file here is not needed for range calculation
80+
8581
# Extract metadata for positioning
86-
table_meta = TokenPositionDetails.from_meta(table.this.meta)
87-
table_range = _range_from_token_position_details(table_meta, read_file)
88-
start_pos = table_range.start
89-
end_pos = table_range.end
90-
82+
this_meta = table.this.meta
83+
this_start_line_0 = this_meta['line'] - 1 # Convert to 0-indexed
84+
this_start_col_0 = this_meta['col'] - 1 # Convert to 0-indexed
85+
86+
# End position: Based on the start of 'this' token + its length
87+
end_char_0 = this_start_col_0 + len(table.this.this)
88+
end_pos = Position(line=this_start_line_0, character=end_char_0)
89+
90+
# Start position: Initially set to the start of 'this' token
91+
start_pos = Position(line=this_start_line_0, character=this_start_col_0)
92+
9193
# If there's a database qualifier, adjust the start position
92-
db = table.args.get("db")
94+
db = table.args.get('db')
9395
if db is not None:
94-
db_meta = TokenPositionDetails.from_meta(db.meta)
95-
db_range = _range_from_token_position_details(db_meta, read_file)
96-
start_pos = db_range.start
97-
96+
db_meta = db.meta
97+
db_start_line_0 = db_meta['line'] - 1 # Convert to 0-indexed
98+
db_start_col_0 = db_meta['col'] - 1 # Convert to 0-indexed
99+
start_pos = Position(line=db_start_line_0, character=db_start_col_0)
100+
98101
# If there's a catalog qualifier, adjust the start position further
99-
catalog = table.args.get("catalog")
102+
catalog = table.args.get('catalog')
100103
if catalog is not None:
101-
catalog_meta = TokenPositionDetails.from_meta(catalog.meta)
102-
catalog_range = _range_from_token_position_details(catalog_meta, read_file)
103-
start_pos = catalog_range.start
104+
catalog_meta = catalog.meta
105+
catalog_start_line_0 = catalog_meta['line'] - 1 # Convert to 0-indexed
106+
catalog_start_col_0 = catalog_meta['col'] - 1 # Convert to 0-indexed
107+
start_pos = Position(line=catalog_start_line_0, character=catalog_start_col_0)
104108

105-
references.append(
106-
Reference(uri=referenced_model_uri, range=Range(start=start_pos, end=end_pos))
107-
)
109+
references.append(Reference(
110+
uri=referenced_model_path,
111+
range=Range(start=start_pos, end=end_pos)
112+
))
108113

109114
return references
110-
111-
112-
class TokenPositionDetails(PydanticModel):
113-
"""
114-
Details about a token's position in the source code.
115-
116-
Attributes:
117-
line (int): The line that the token ends on.
118-
col (int): The column that the token ends on.
119-
start (int): The start index of the token.
120-
end (int): The ending index of the token.
121-
"""
122-
123-
line: int
124-
col: int
125-
start: int
126-
end: int
127-
128-
@staticmethod
129-
def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
130-
return TokenPositionDetails(
131-
line=meta["line"],
132-
col=meta["col"],
133-
start=meta["start"],
134-
end=meta["end"],
135-
)
136-
137-
138-
def _range_from_token_position_details(
139-
token_position_details: TokenPositionDetails, read_file: t.List[str]
140-
) -> Range:
141-
"""
142-
Convert a TokenPositionDetails object to a Range object.
143-
144-
:param token_position_details: Details about a token's position
145-
:param read_file: List of lines from the file
146-
:return: A Range object representing the token's position
147-
"""
148-
# Convert from 1-indexed to 0-indexed for line and column
149-
end_line_0 = token_position_details.line - 1
150-
end_col_0 = token_position_details.col
151-
152-
# Find the start line and column by counting backwards from the end position
153-
start_pos = token_position_details.start
154-
end_pos = token_position_details.end
155-
156-
# Initialize with the end position
157-
start_line_0 = end_line_0
158-
start_col_0 = end_col_0 - (end_pos - start_pos + 1)
159-
160-
# If start_col_0 is negative, we need to go back to previous lines
161-
while start_col_0 < 0 and start_line_0 > 0:
162-
start_line_0 -= 1
163-
start_col_0 += len(read_file[start_line_0])
164-
# Account for newline character
165-
if start_col_0 >= 0:
166-
break
167-
start_col_0 += 1 # For the newline character
168-
169-
# Ensure we don't have negative values
170-
start_col_0 = max(0, start_col_0)
171-
return Range(
172-
start=Position(line=start_line_0, character=start_col_0),
173-
end=Position(line=end_line_0, character=end_col_0),
174-
)

vscode/extension/package.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@
5757
"command": "sqlmesh.signout",
5858
"title": "Sign out from Tobiko Cloud",
5959
"description": "SQLMesh"
60+
},
61+
{
62+
"command": "sqlmesh.allModels",
63+
"title": "Get all models",
64+
"description": "SQLMesh"
6065
}
6166
]
6267
},
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
const selector: vscode.DocumentSelector = {
2+
pattern: '**/my-special-file.json' // ← only this file (you can use any glob)
3+
};
4+
5+
const provider: vscode.CompletionItemProvider = {
6+
provideCompletionItems(document, position) {
7+
const items: vscode.CompletionItem[] = [];
8+
9+
const item = new vscode.CompletionItem('mySuggestion', vscode.CompletionItemKind.Keyword);
10+
item.detail = 'Inserted by my extension';
11+
item.insertText = 'mySuggestion';
12+
items.push(item);
13+
14+
return items;
15+
}
16+
};
17+

vscode/extension/src/extension.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ export async function activate(context: vscode.ExtensionContext) {
8282
context.subscriptions.push(lspClient)
8383
}
8484

85+
ctx.subscriptions.push(
86+
vscode.languages.registerCompletionItemProvider(selector, provider, ...['.', ':'])
87+
);
88+
8589
const restart = async () => {
8690
if (lspClient) {
8791
traceVerbose('Restarting LSP client')

vscode/extension/src/lsp/custom.ts

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
export type AllModelsMethod = {
2+
method: 'sqlmesh/all_models';
3+
request: AllModelsRequest;
4+
response: AllModelsResponse;
5+
};
6+
7+
export type TestMethod = {
8+
method: 'sqlmesh/test';
9+
request: TestRequest;
10+
response: TestResponse;
11+
};
12+
13+
export type CustomLSPMethods = AllModelsMethod | TestMethod;
14+
15+
interface AllModelsRequest {
16+
textDocument: {
17+
uri: string
18+
}
19+
}
20+
21+
interface AllModelsResponse {
22+
models: string[]
23+
}
24+
25+
interface TestRequest {
26+
foo: string
27+
}
28+
29+
interface TestResponse {
30+
bar: string
31+
}

vscode/extension/src/lsp/lsp.ts

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { err, isErr, ok, Result } from '../utilities/functional/result'
1010
import { getWorkspaceFolders } from '../utilities/common/vscodeapi'
1111
import { traceError } from '../utilities/common/log'
1212
import { ErrorType } from '../utilities/errors'
13+
import { CustomLSPMethods } from './custom'
1314

1415
let outputChannel: OutputChannel | undefined
1516

@@ -27,7 +28,9 @@ export class LSPClient implements Disposable {
2728

2829
const sqlmesh = await sqlmesh_lsp_exec()
2930
if (isErr(sqlmesh)) {
30-
traceError(`Failed to get sqlmesh_lsp_exec, ${JSON.stringify(sqlmesh.error)}`)
31+
traceError(
32+
`Failed to get sqlmesh_lsp_exec, ${JSON.stringify(sqlmesh.error)}`,
33+
)
3134
return sqlmesh
3235
}
3336
const workspaceFolders = getWorkspaceFolders()
@@ -96,4 +99,23 @@ export class LSPClient implements Disposable {
9699
public async dispose() {
97100
await this.stop()
98101
}
102+
103+
public async call_custom_method<
104+
Method extends CustomLSPMethods['method'],
105+
Request extends Extract<CustomLSPMethods, { method: Method }>['request'],
106+
Response extends Extract<CustomLSPMethods, { method: Method }>['response'],
107+
>(method: Method, request: Request): Promise<Result<Response, string>> {
108+
if (!this.client) {
109+
return err('lsp client not ready')
110+
}
111+
try {
112+
const result = await this.client.sendRequest<Response>(method, request)
113+
return ok(result)
114+
} catch (error) {
115+
traceError(
116+
`lsp '${method}' request ${JSON.stringify(request)} failed: ${JSON.stringify(error)}`,
117+
)
118+
return err(JSON.stringify(error))
119+
}
120+
}
99121
}

0 commit comments

Comments
 (0)