|
5 | 5 | from sqlmesh.core.model.definition import SqlModel |
6 | 6 | from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget |
7 | 7 | from sqlglot import exp |
| 8 | +from sqlglot.optimizer.scope import build_scope |
8 | 9 | from sqlmesh.lsp.uri import URI |
9 | 10 | from sqlmesh.utils.pydantic import PydanticModel |
10 | 11 |
|
11 | 12 |
|
12 | 13 | class Reference(PydanticModel): |
13 | 14 | """ |
14 | | - A reference to a model. |
| 15 | + A reference to a model or CTE. |
15 | 16 |
|
16 | 17 | Attributes: |
17 | 18 | range: The range of the reference in the source file |
18 | | - uri: The uri of the referenced model |
19 | | - description: The description of the referenced model |
| 19 | + uri: The uri of the referenced model or file |
| 20 | + description: The description of the referenced model or CTE |
| 21 | + target_range: The range of the definition for go-to-definition (optional, used for CTEs) |
20 | 22 | """ |
21 | 23 |
|
22 | 24 | range: Range |
23 | 25 | uri: str |
24 | 26 | description: t.Optional[str] = None |
| 27 | + target_range: t.Optional[Range] = None |
25 | 28 |
|
26 | 29 |
|
27 | 30 | def by_position(position: Position) -> t.Callable[[Reference], bool]: |
@@ -87,6 +90,7 @@ def get_model_definitions_for_a_path( |
87 | 90 | - Need to normalize it before matching |
88 | 91 | - Try get_model before normalization |
89 | 92 | - Match to models that the model refers to |
| 93 | + - Also find CTE references within the query |
90 | 94 | """ |
91 | 95 | path = document_uri.to_path() |
92 | 96 | if path.suffix != ".sql": |
@@ -125,64 +129,121 @@ def get_model_definitions_for_a_path( |
125 | 129 | # Find all possible references |
126 | 130 | references = [] |
127 | 131 |
|
128 | | - # Get SQL query and find all table references |
129 | | - tables = list(query.find_all(exp.Table)) |
130 | | - if len(tables) == 0: |
131 | | - return [] |
132 | | - |
133 | 132 | with open(file_path, "r", encoding="utf-8") as file: |
134 | 133 | read_file = file.readlines() |
135 | 134 |
|
136 | | - for table in tables: |
137 | | - # Normalize the table reference |
138 | | - unaliased = table.copy() |
139 | | - if unaliased.args.get("alias") is not None: |
140 | | - unaliased.set("alias", None) |
141 | | - reference_name = unaliased.sql(dialect=dialect) |
142 | | - try: |
143 | | - normalized_reference_name = normalize_model_name( |
144 | | - reference_name, |
145 | | - default_catalog=lint_context.context.default_catalog, |
146 | | - dialect=dialect, |
147 | | - ) |
148 | | - if normalized_reference_name not in depends_on: |
149 | | - continue |
150 | | - except Exception: |
151 | | - # Skip references that cannot be normalized |
152 | | - continue |
153 | | - |
154 | | - # Get the referenced model uri |
155 | | - referenced_model = lint_context.context.get_model( |
156 | | - model_or_snapshot=normalized_reference_name, raise_if_missing=False |
157 | | - ) |
158 | | - if referenced_model is None: |
159 | | - continue |
160 | | - referenced_model_path = referenced_model._path |
161 | | - # Check whether the path exists |
162 | | - if not referenced_model_path.is_file(): |
163 | | - continue |
164 | | - referenced_model_uri = URI.from_path(referenced_model_path) |
165 | | - |
166 | | - # Extract metadata for positioning |
167 | | - table_meta = TokenPositionDetails.from_meta(table.this.meta) |
168 | | - table_range = _range_from_token_position_details(table_meta, read_file) |
169 | | - start_pos = table_range.start |
170 | | - end_pos = table_range.end |
171 | | - |
172 | | - # If there's a catalog or database qualifier, adjust the start position |
173 | | - catalog_or_db = table.args.get("catalog") or table.args.get("db") |
174 | | - if catalog_or_db is not None: |
175 | | - catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) |
176 | | - catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file) |
177 | | - start_pos = catalog_or_db_range.start |
178 | | - |
179 | | - references.append( |
180 | | - Reference( |
181 | | - uri=referenced_model_uri.value, |
182 | | - range=Range(start=start_pos, end=end_pos), |
183 | | - description=referenced_model.description, |
184 | | - ) |
185 | | - ) |
| 135 | + # Build scope tree to properly handle nested CTEs |
| 136 | + root_scope = build_scope(query) |
| 137 | + |
| 138 | + if root_scope: |
| 139 | + # Traverse all scopes to find CTE definitions and table references |
| 140 | + for scope in root_scope.traverse(): |
| 141 | + # Build a map of CTE names to their definitions within this scope |
| 142 | + cte_definitions = {} |
| 143 | + |
| 144 | + # For CTEs defined in this scope |
| 145 | + for cte in scope.ctes: |
| 146 | + if cte.alias: |
| 147 | + cte_definitions[cte.alias] = cte |
| 148 | + |
| 149 | + # Also include CTEs from parent scopes (for references inside nested CTEs) |
| 150 | + parent = scope.parent |
| 151 | + while parent: |
| 152 | + for cte in parent.ctes: |
| 153 | + if cte.alias and cte.alias not in cte_definitions: |
| 154 | + cte_definitions[cte.alias] = cte |
| 155 | + parent = parent.parent |
| 156 | + |
| 157 | + # Get all table references in this scope |
| 158 | + tables = list(scope.find_all(exp.Table)) |
| 159 | + |
| 160 | + for table in tables: |
| 161 | + table_name = table.name |
| 162 | + |
| 163 | + # Check if this table reference is a CTE in the current scope |
| 164 | + if table_name in cte_definitions: |
| 165 | + try: |
| 166 | + # This is a CTE reference - create a reference to the CTE definition |
| 167 | + cte_def = cte_definitions[table_name] |
| 168 | + args = cte_def.args["alias"] |
| 169 | + if args and isinstance(args, exp.TableAlias): |
| 170 | + identifier = args.this |
| 171 | + if isinstance(identifier, exp.Identifier): |
| 172 | + meta = identifier.meta |
| 173 | + |
| 174 | + table_meta_obj = TokenPositionDetails.from_meta(meta) |
| 175 | + target_range = _range_from_token_position_details( |
| 176 | + table_meta_obj, read_file |
| 177 | + ) |
| 178 | + |
| 179 | + table_meta_obj = TokenPositionDetails.from_meta(table.this.meta) |
| 180 | + table_range = _range_from_token_position_details( |
| 181 | + table_meta_obj, read_file |
| 182 | + ) |
| 183 | + |
| 184 | + references.append( |
| 185 | + Reference( |
| 186 | + uri=document_uri.value, # Same file |
| 187 | + range=table_range, |
| 188 | + target_range=target_range, |
| 189 | + ) |
| 190 | + ) |
| 191 | + except Exception: |
| 192 | + pass |
| 193 | + continue |
| 194 | + |
| 195 | + # For non-CTE tables, process as before (external model references) |
| 196 | + # Normalize the table reference |
| 197 | + unaliased = table.copy() |
| 198 | + if unaliased.args.get("alias") is not None: |
| 199 | + unaliased.set("alias", None) |
| 200 | + reference_name = unaliased.sql(dialect=dialect) |
| 201 | + try: |
| 202 | + normalized_reference_name = normalize_model_name( |
| 203 | + reference_name, |
| 204 | + default_catalog=lint_context.context.default_catalog, |
| 205 | + dialect=dialect, |
| 206 | + ) |
| 207 | + if normalized_reference_name not in depends_on: |
| 208 | + continue |
| 209 | + except Exception: |
| 210 | + # Skip references that cannot be normalized |
| 211 | + continue |
| 212 | + |
| 213 | + # Get the referenced model uri |
| 214 | + referenced_model = lint_context.context.get_model( |
| 215 | + model_or_snapshot=normalized_reference_name, raise_if_missing=False |
| 216 | + ) |
| 217 | + if referenced_model is None: |
| 218 | + continue |
| 219 | + referenced_model_path = referenced_model._path |
| 220 | + # Check whether the path exists |
| 221 | + if not referenced_model_path.is_file(): |
| 222 | + continue |
| 223 | + referenced_model_uri = URI.from_path(referenced_model_path) |
| 224 | + |
| 225 | + # Extract metadata for positioning |
| 226 | + table_meta = TokenPositionDetails.from_meta(table.this.meta) |
| 227 | + table_range = _range_from_token_position_details(table_meta, read_file) |
| 228 | + start_pos = table_range.start |
| 229 | + end_pos = table_range.end |
| 230 | + |
| 231 | + # If there's a catalog or database qualifier, adjust the start position |
| 232 | + catalog_or_db = table.args.get("catalog") or table.args.get("db") |
| 233 | + if catalog_or_db is not None: |
| 234 | + catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) |
| 235 | + catalog_or_db_range = _range_from_token_position_details( |
| 236 | + catalog_or_db_meta, read_file |
| 237 | + ) |
| 238 | + start_pos = catalog_or_db_range.start |
| 239 | + |
| 240 | + references.append( |
| 241 | + Reference( |
| 242 | + uri=referenced_model_uri.value, |
| 243 | + range=Range(start=start_pos, end=end_pos), |
| 244 | + description=referenced_model.description, |
| 245 | + ) |
| 246 | + ) |
186 | 247 |
|
187 | 248 | return references |
188 | 249 |
|
|
0 commit comments