|
11 | 11 |
|
12 | 12 | class Reference(PydanticModel): |
13 | 13 | """ |
14 | | - A reference to a model. |
| 14 | + A reference to a model or CTE. |
15 | 15 |
|
16 | 16 | Attributes: |
17 | 17 | range: The range of the reference in the source file |
18 | | - uri: The uri of the referenced model |
19 | | - description: The description of the referenced model |
| 18 | + uri: The uri of the referenced model or file |
| 19 | + description: The description of the referenced model or CTE |
| 20 | + target_range: The range of the definition for go-to-definition (optional, used for CTEs) |
20 | 21 | """ |
21 | 22 |
|
22 | 23 | range: Range |
23 | 24 | uri: str |
24 | 25 | description: t.Optional[str] = None |
| 26 | + target_range: t.Optional[Range] = None |
25 | 27 |
|
26 | 28 |
|
27 | 29 | def by_position(position: Position) -> t.Callable[[Reference], bool]: |
@@ -87,6 +89,7 @@ def get_model_definitions_for_a_path( |
87 | 89 | - Need to normalize it before matching |
88 | 90 | - Try get_model before normalization |
89 | 91 | - Match to models that the model refers to |
| 92 | + - Also find CTE references within the query |
90 | 93 | """ |
91 | 94 | path = document_uri.to_path() |
92 | 95 | if path.suffix != ".sql": |
@@ -127,13 +130,49 @@ def get_model_definitions_for_a_path( |
127 | 130 |
|
128 | 131 | # Get SQL query and find all table references |
129 | 132 | tables = list(query.find_all(exp.Table)) |
130 | | - if len(tables) == 0: |
131 | | - return [] |
132 | 133 |
|
133 | 134 | with open(file_path, "r", encoding="utf-8") as file: |
134 | 135 | read_file = file.readlines() |
135 | 136 |
|
| 137 | + # Build a map of CTE names to their definitions for CTE go-to-definition |
| 138 | + cte_definitions = {} |
| 139 | + with_clause = query.find(exp.With) |
| 140 | + if with_clause: |
| 141 | + for cte in with_clause.expressions: |
| 142 | + if isinstance(cte.alias, str): |
| 143 | + cte_definitions[cte.alias] = cte |
| 144 | + |
136 | 145 | for table in tables: |
| 146 | + table_name = table.name |
| 147 | + |
| 148 | + # Check if this table reference is a CTE |
| 149 | + if table_name in cte_definitions: |
| 150 | + try: |
| 151 | + # This is a CTE reference - create a reference to the CTE definition |
| 152 | + cte_def = cte_definitions[table_name] |
| 153 | + args = cte_def.args["alias"] |
| 154 | + if args and isinstance(args, exp.TableAlias): |
| 155 | + identifier = args.this |
| 156 | + if isinstance(identifier, exp.Identifier): |
| 157 | + meta = identifier.meta |
| 158 | + |
| 159 | + table_meta = TokenPositionDetails.from_meta(meta) |
| 160 | + target_range = _range_from_token_position_details(table_meta, read_file) |
| 161 | + table_meta = TokenPositionDetails.from_meta(table.this.meta) |
| 162 | + table_range = _range_from_token_position_details(table_meta, read_file) |
| 163 | + |
| 164 | + references.append( |
| 165 | + Reference( |
| 166 | + uri=document_uri.value, # Same file |
| 167 | + range=table_range, |
| 168 | + target_range=target_range, |
| 169 | + ) |
| 170 | + ) |
| 171 | + except Exception: |
| 172 | + pass |
| 173 | + continue |
| 174 | + |
| 175 | + # For non-CTE tables, process as before (external model references) |
137 | 176 | # Normalize the table reference |
138 | 177 | unaliased = table.copy() |
139 | 178 | if unaliased.args.get("alias") is not None: |
|
0 commit comments