@@ -57,11 +57,8 @@ def test_hints() -> None:
5757 start_line = 0 ,
5858 end_line = 9999 ,
5959 )
60- assert len (crbd_hints ) == 4
60+ assert len (crbd_hints ) == 1
6161 assert crbd_hints [0 ].label == "::INT"
62- assert crbd_hints [1 ].label == "::DOUBLE"
63- assert crbd_hints [2 ].label == "::INT"
64- assert crbd_hints [3 ].label == "::DATE"
6562
6663
6764@pytest .mark .fast
@@ -105,3 +102,102 @@ def test_complex_hints() -> None:
105102 assert len (result ) == 2
106103 assert result [0 ].label == "::VARCHAR(100)"
107104 assert result [1 ].label == "::STRUCT<INT, STRUCT<TEXT, INT[]>>"
105+
106+
107+ @pytest .mark .fast
108+ def test_simple_cast_hints () -> None :
109+ """Don't add type hints if the expression is already a cast"""
110+ query = parse_one ("SELECT a::INT, CAST(b AS DATE), c FROM d" , dialect = "postgres" )
111+
112+ result = _get_type_hints_for_model_from_query (
113+ query = query ,
114+ dialect = "postgres" ,
115+ columns_to_types = {
116+ "a" : exp .DataType .build ("INT" ),
117+ "b" : exp .DataType .build ("DATE" ),
118+ "c" : exp .DataType .build ("TEXT" ),
119+ },
120+ start_line = 0 ,
121+ end_line = 1 ,
122+ )
123+
124+ assert len (result ) == 1
125+ assert result [0 ].label == "::TEXT"
126+
127+
128+ @pytest .mark .fast
129+ def test_alias_cast_hints () -> None :
130+ """Don't add type hints if the expression is already a cast"""
131+ query = parse_one (
132+ "SELECT raw_a::INT AS a, CAST(raw_b AS DATE) AS b, c FROM d" , dialect = "postgres"
133+ )
134+
135+ result = _get_type_hints_for_model_from_query (
136+ query = query ,
137+ dialect = "postgres" ,
138+ columns_to_types = {
139+ "a" : exp .DataType .build ("INT" ),
140+ "b" : exp .DataType .build ("DATE" ),
141+ "c" : exp .DataType .build ("TEXT" ),
142+ },
143+ start_line = 0 ,
144+ end_line = 1 ,
145+ )
146+
147+ assert len (result ) == 1
148+ assert result [0 ].label == "::TEXT"
149+
150+
151+ @pytest .mark .fast
152+ def test_simple_cte_hints () -> None :
153+ """Don't add type hints if the expression is already a cast"""
154+ query = parse_one ("WITH t AS (SELECT a FROM b) SELECT a AS c FROM t" , dialect = "postgres" )
155+
156+ result = _get_type_hints_for_model_from_query (
157+ query = query ,
158+ dialect = "postgres" ,
159+ columns_to_types = {
160+ "c" : exp .DataType .build ("INT" ),
161+ },
162+ start_line = 0 ,
163+ end_line = 1 ,
164+ )
165+
166+ assert len (result ) == 1
167+ assert result [0 ].label == "::INT"
168+
169+
170+ @pytest .mark .fast
171+ def test_cte_with_union_hints () -> None :
172+ """Don't add type hints if the expression is already a cast"""
173+ query = parse_one (
174+ """WITH x AS (SELECT a FROM t),
175+ y AS (SELECT b FROM t),
176+ z AS (SELECT c FROM t)
177+ SELECT a AS d FROM x
178+ UNION
179+ SELECT b AS e FROM y
180+ UNION
181+ SELECT c AS f FROM z""" ,
182+ dialect = "postgres" ,
183+ )
184+
185+ result = _get_type_hints_for_model_from_query (
186+ query = query ,
187+ dialect = "postgres" ,
188+ columns_to_types = {
189+ "a" : exp .DataType .build ("TEXT" ),
190+ "b" : exp .DataType .build ("DATE" ),
191+ "c" : exp .DataType .build ("INT" ),
192+ "d" : exp .DataType .build ("TEXT" ),
193+ "e" : exp .DataType .build ("DATE" ),
194+ "f" : exp .DataType .build ("INT" ),
195+ },
196+ start_line = 0 ,
197+ end_line = 9999 ,
198+ )
199+
200+ assert len (result ) == 3
201+ assert result [0 ].label == "::TEXT"
202+ assert result [1 ].label == "::DATE"
203+ assert result [2 ].label == "::INT"
0 commit comments