|
3 | 3 |
|
4 | 4 | import logging |
5 | 5 | import typing as t |
6 | | -import weakref |
7 | 6 | from contextlib import suppress |
8 | 7 | from pathlib import Path |
9 | 8 |
|
|
12 | 11 | from pygls.server import LanguageServer |
13 | 12 | from pygls.workspace import TextDocument |
14 | 13 | from sqlmesh._version import __version__ |
| 14 | +from sqlmesh.core.dialect import Audit |
15 | 15 | from sqlmesh.core.model import Model |
16 | 16 |
|
17 | 17 | logger = logging.getLogger(__name__) |
18 | 18 |
|
19 | | -CONTEXTS: t.Dict[str, Context] = {} |
20 | | -"""A mapping of workspace paths to SQLMesh contexts.""" |
| 19 | +GLOBAL_CONTEXT: t.Optional[ |
| 20 | + t.Tuple[ |
| 21 | + Context, |
| 22 | + t.Dict[str, t.Tuple[Context, t.Union[Model, Audit]]], |
| 23 | + ] |
| 24 | +] = None |
21 | 25 |
|
22 | | -PATHS_TO_MODELS: t.Dict[str, t.Tuple[Context, Model]] = {} |
23 | | -"""A mapping of file paths to SQLMesh (context, model) tuples.""" |
24 | 26 |
|
25 | 27 | server = LanguageServer("sqlmesh_lsp", __version__) |
26 | 28 |
|
27 | | -_CACHE: t.Set[str] = set() |
28 | | -"""A cache of URIs for which we have already ensured a context exists.""" |
29 | | - |
30 | 29 |
|
31 | 30 | def ensure_context_for_document(document: TextDocument) -> TextDocument: |
32 | | - """Ensure that a context exists for the given document if applicable.""" |
33 | | - if document.uri in _CACHE: |
34 | | - return document |
35 | | - _CACHE.add(document.uri) |
| 31 | + """Ensure that a context exists for the given document if applicable by searching for a config.py or config.yml file in the parent directories.""" |
| 32 | + # If the context is already loaded, return the document, if it is part of the same context |
| 33 | + global GLOBAL_CONTEXT |
| 34 | + if GLOBAL_CONTEXT is not None: |
| 35 | + GLOBAL_CONTEXT[0].load() |
| 36 | + if document.uri in GLOBAL_CONTEXT[1]: |
| 37 | + return document |
| 38 | + else: |
| 39 | + for model in GLOBAL_CONTEXT[0]._models.values(): |
| 40 | + if model._path is None: |
| 41 | + continue |
| 42 | + path = model._path.resolve() |
| 43 | + if path == document.path: |
| 44 | + GLOBAL_CONTEXT = ( |
| 45 | + GLOBAL_CONTEXT[0], |
| 46 | + GLOBAL_CONTEXT[1] | {document.uri: (GLOBAL_CONTEXT[0], model)}, |
| 47 | + ) |
| 48 | + return document |
| 49 | + for audit in GLOBAL_CONTEXT[0]._audits.values(): |
| 50 | + if audit._path is None: |
| 51 | + continue |
| 52 | + path = audit._path.resolve() |
| 53 | + if path == document.path: |
| 54 | + GLOBAL_CONTEXT = ( |
| 55 | + GLOBAL_CONTEXT[0], |
| 56 | + GLOBAL_CONTEXT[1] | {document.uri: (GLOBAL_CONTEXT[0], audit)}, |
| 57 | + ) |
| 58 | + return document |
| 59 | + return document |
| 60 | + |
| 61 | + # If there is no context, load the context and then call this function again |
36 | 62 | path = Path(document.path).resolve() |
37 | 63 | if path.suffix not in (".sql", ".py"): |
38 | 64 | return document |
39 | | - initial_path = path |
40 | | - while path.parents: |
41 | | - if str(path) in CONTEXTS: |
42 | | - return document |
43 | | - path = path.parent |
44 | | - path = initial_path |
45 | 65 | loaded = False |
46 | 66 | while path.parents and not loaded: |
47 | 67 | for ext in ("py", "yml", "yaml"): |
48 | 68 | config_path = path / f"config.{ext}" |
49 | 69 | if config_path.exists(): |
50 | 70 | with suppress(Exception): |
51 | | - handle = Context(paths=[f"{path}"]) |
52 | | - CONTEXTS[str(path)] = handle |
53 | | - PATHS_TO_MODELS.update( |
54 | | - { |
55 | | - str(model._path.resolve()): (handle, weakref.proxy(model)) |
56 | | - for model in handle.models.values() |
57 | | - } |
58 | | - ) |
| 71 | + handle = Context(paths=[path]) |
| 72 | + GLOBAL_CONTEXT = (handle, {}) |
59 | 73 | server.show_message(f"Context loaded for: {path}") |
60 | 74 | loaded = True |
61 | | - break |
| 75 | + return ensure_context_for_document(document) |
62 | 76 | path = path.parent |
| 77 | + |
63 | 78 | return document |
64 | 79 |
|
65 | 80 |
|
66 | 81 | @server.feature(types.TEXT_DOCUMENT_FORMATTING) |
67 | 82 | def formatting( |
68 | 83 | ls: LanguageServer, params: types.DocumentFormattingParams |
69 | 84 | ) -> t.List[types.TextEdit]: |
70 | | - """Format the document based using SQLMesh format_model_expressions.""" |
| 85 | + """Format the document using SQLMesh format_model_expressions.""" |
71 | 86 | try: |
72 | | - logger.info(f"Formatting document: {params.text_document.uri}") |
73 | 87 | document = ensure_context_for_document(ls.workspace.get_document(params.text_document.uri)) |
74 | | - context, _ = PATHS_TO_MODELS.get(document.path, (None, None)) |
| 88 | + context = GLOBAL_CONTEXT |
75 | 89 | if context is None: |
76 | | - logger.error(f"No context found for document: {document.path}") |
77 | | - return [] |
78 | | - context.format(paths=(Path(document.path),)) |
| 90 | + raise Exception(f"No context found for document: {document.path}") |
| 91 | + context[0].format(paths=(Path(document.path),)) |
79 | 92 | with open(document.path, "r+", encoding="utf-8") as file: |
80 | 93 | return [ |
81 | 94 | types.TextEdit( |
|
0 commit comments