@@ -61,7 +61,7 @@ def get_hints(
6161
6262
6363def _get_type_hints_for_select (
64- expression : exp .Select ,
64+ expression : exp .Expression ,
6565 dialect : str ,
6666 columns_to_types : t .Dict [str , exp .DataType ],
6767 start_line : int ,
@@ -113,28 +113,6 @@ def _get_type_hints_for_select(
113113 return hints
114114
115115
116- def _get_type_hints_for_expression (
117- expression : Expression ,
118- dialect : str ,
119- columns_to_types : t .Dict [str , exp .DataType ],
120- start_line : int ,
121- end_line : int ,
122- ) -> t .List [types .InlayHint ]:
123- if isinstance (expression , exp .Union ):
124- return _get_type_hints_for_expression (
125- expression .this , dialect , columns_to_types , start_line , end_line
126- ) + _get_type_hints_for_expression (
127- expression .expression , dialect , columns_to_types , start_line , end_line
128- )
129-
130- if isinstance (expression , exp .Select ):
131- return _get_type_hints_for_select (
132- expression , dialect , columns_to_types , start_line , end_line
133- )
134-
135- return []
136-
137-
138116def _get_type_hints_for_model_from_query (
139117 query : Expression ,
140118 dialect : str ,
@@ -150,8 +128,13 @@ def _get_type_hints_for_model_from_query(
150128 if not root :
151129 return []
152130
153- return _get_type_hints_for_expression (
154- root .expression , dialect , columns_to_types , start_line , end_line
155- )
131+ return [
132+ hint
133+ for q in query .walk (prune = lambda n : not isinstance (n , exp .SetOperation ))
134+ if isinstance (select := q .unnest (), exp .Select )
135+ for hint in _get_type_hints_for_select (
136+ q , dialect , columns_to_types , start_line , end_line
137+ )
138+ ]
156139 except Exception :
157140 return []
0 commit comments