|
15 | 15 |
|
16 | 16 | logger = logging.getLogger(__name__) |
17 | 17 |
|
18 | | -CONTEXTS: t.Dict[str, Context] = {} |
19 | | -"""A mapping of workspace paths to SQLMesh contexts.""" |
| 18 | +CONTEXT: t.Optional[ |
| 19 | + t.Tuple[ |
| 20 | + Context, |
| 21 | + t.Dict[str, t.Tuple[Context, Model]], |
| 22 | + ] |
| 23 | +] = None |
20 | 24 |
|
21 | | -PATHS_TO_MODELS: t.Dict[str, t.Tuple[Context, Model]] = {} |
22 | | -"""A mapping of file paths to SQLMesh (context, model) tuples.""" |
23 | 25 |
|
24 | 26 | server = LanguageServer("sqlmesh_lsp", __version__) |
25 | 27 |
|
26 | | -_CACHE: t.Set[str] = set() |
27 | | -"""A cache of URIs for which we have already ensured a context exists.""" |
28 | | - |
29 | 28 |
|
30 | 29 | def ensure_context_for_document(document: TextDocument) -> TextDocument: |
31 | 30 | """Ensure that a context exists for the given document if applicable by searching for a config.py or config.yml file in the parent directories.""" |
32 | | - if document.uri in _CACHE: |
33 | | - return document |
34 | | - _CACHE.add(document.uri) |
| 31 | + # If the context is already loaded, return the document, if it is part of the same context |
| 32 | + if CONTEXT is not None: |
| 33 | + CONTEXT[0].reload() |
| 34 | + if document.uri in CONTEXT[1]: |
| 35 | + return document |
| 36 | + else: |
| 37 | + for model in CONTEXT[1]._models: |
| 38 | + path = model._path.resolve() |
| 39 | + if path == document.path: |
| 40 | + CONTEXT = (CONTEXT[0], CONTEXT[1] | {document.uri: (CONTEXT[0], model)}) |
| 41 | + return document |
| 42 | + for audit in CONTEXT[1]._audits: |
| 43 | + path = audit._path.resolve() |
| 44 | + if path == document.path: |
| 45 | + CONTEXT = (CONTEXT[0], CONTEXT[1] | {document.uri: (CONTEXT[0], audit)}) |
| 46 | + return document |
| 47 | + return document |
| 48 | + |
| 49 | + # If there is no context, load the context and then call this function again |
35 | 50 | path = Path(document.path).resolve() |
36 | 51 | if path.suffix not in (".sql", ".py"): |
37 | 52 | return document |
38 | | - initial_path = path |
39 | | - while path.parents: |
40 | | - if str(path) in CONTEXTS: |
41 | | - return document |
42 | | - path = path.parent |
43 | | - path = initial_path |
44 | 53 | loaded = False |
45 | 54 | while path.parents and not loaded: |
46 | 55 | for ext in ("py", "yml", "yaml"): |
47 | 56 | config_path = path / f"config.{ext}" |
48 | 57 | if config_path.exists(): |
49 | 58 | with suppress(Exception): |
50 | | - handle = Context(paths=[f"{path}"]) |
51 | | - CONTEXTS[str(path)] = handle |
52 | | - PATHS_TO_MODELS.update( |
53 | | - { |
54 | | - str(model._path.resolve()): (handle, model) |
55 | | - for model in handle.models.values() |
56 | | - } |
57 | | - ) |
| 59 | + handle = Context(paths=[path]) |
| 60 | + CONTEXT = (handle, {}) |
58 | 61 | server.show_message(f"Context loaded for: {path}") |
59 | 62 | loaded = True |
60 | | - break |
| 63 | + return ensure_context_for_document(document) |
61 | 64 | path = path.parent |
| 65 | + |
62 | 66 | return document |
63 | 67 |
|
64 | 68 |
|
65 | 69 | @server.feature(types.TEXT_DOCUMENT_FORMATTING) |
66 | 70 | def formatting( |
67 | 71 | ls: LanguageServer, params: types.DocumentFormattingParams |
68 | 72 | ) -> t.List[types.TextEdit]: |
69 | | - """Format the document based using SQLMesh format_model_expressions.""" |
| 73 | + """Format the document using SQLMesh format_model_expressions.""" |
70 | 74 | try: |
71 | 75 | document = ensure_context_for_document(ls.workspace.get_document(params.text_document.uri)) |
72 | | - context, _ = PATHS_TO_MODELS.get(document.path, (None, None)) |
| 76 | + context = CONTEXT |
73 | 77 | if context is None: |
74 | 78 | raise Exception(f"No context found for document: {document.path}") |
75 | 79 | context.format(paths=(Path(document.path),)) |
|
0 commit comments