@@ -138,6 +138,26 @@ def _get_data_objects(
138138 for row in df .itertuples ()
139139 ]
140140
def table_exists(self, table_name: TableName) -> bool:
    """Return whether ``table_name`` can be queried through this adapter.

    The data object cache is consulted first; a cached entry answers the
    question without running a query. Otherwise a ``SELECT 1 ... LIMIT 0``
    probe is executed against the table.

    Args:
        table_name: The (optionally catalog/schema qualified) table to check.

    Returns:
        True if the cache holds a non-None entry for the table or the probe
        query succeeds; False if the cached entry is None or the probe raises.
    """
    from sqlmesh.core.engine_adapter.base import _get_data_object_cache_key

    table = exp.to_table(table_name)
    data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)

    if data_object_cache_key in self._data_object_cache:
        logger.debug("Table existence cache hit: %s", data_object_cache_key)
        return self._data_object_cache[data_object_cache_key] is not None

    try:
        # We don't use DESCRIBE because it fails with "Unsupported ddl with 2 catalogs"
        # for cross-catalog queries in Athena.
        # Since table_exists isn't run with the set_catalog decorator (which sets the
        # QueryExecutionContext), we must fall back to a query that works with fully
        # qualified names. A LIMIT 0 select works with fully qualified names in Athena.
        self.execute(exp.select("1").from_(table).limit(0))
    except Exception as ex:
        # Any failure is still treated as "table does not exist" (best effort), but the
        # reason is logged so that permission or connectivity problems are not
        # silently mistaken for a missing table.
        logger.debug("Table existence probe failed for %s: %s", data_object_cache_key, ex)
        return False
    return True
160+
141161 def columns (
142162 self , table_name : TableName , include_pseudo_columns : bool = False
143163 ) -> t .Dict [str , exp .DataType ]:
@@ -152,11 +172,81 @@ def columns(
152172 .where (exp .column ("table_schema" ).eq (table .db ), exp .column ("table_name" ).eq (table .name ))
153173 .order_by ("ordinal_position" )
154174 )
155- result = self .fetchdf (query , quote_identifiers = True )
156- return {
157- str (r .column_name ): exp .DataType .build (str (r .data_type ))
158- for r in result .itertuples (index = False )
159- }
175+
176+ try :
177+ result = self .fetchdf (query , quote_identifiers = True )
178+ return {
179+ str (r .column_name ): exp .DataType .build (str (r .data_type ))
180+ for r in result .itertuples (index = False )
181+ }
182+ except Exception as e :
183+ # If information_schema query fails, we fallback to DESCRIBE.
184+ # But DESCRIBE with multiple catalogs fails in Athena, so we strip the catalog here
185+ # and rely on the set_current_catalog mechanism (applied at the EngineAdapter method level)
186+ # to set the catalog in the execution context.
187+ describe_table = table .copy ()
188+ catalog = describe_table .catalog
189+ current_catalog = self .get_current_catalog ()
190+
191+ if catalog and catalog != self ._default_catalog :
192+ describe_table .set ("catalog" , None )
193+ if catalog != current_catalog :
194+ self .set_current_catalog (catalog )
195+
196+ try :
197+ self .execute (exp .Describe (this = describe_table , kind = "TABLE" ))
198+
199+ from sqlmesh .core .engine_adapter .base import _decoded_str
200+ import itertools
201+ describe_output = self .cursor .fetchall ()
202+ return {
203+ # Note: MySQL returns the column type as bytes.
204+ column_name : exp .DataType .build (_decoded_str (column_type ), dialect = self .dialect )
205+ for column_name , column_type , * _ in itertools .takewhile (
206+ lambda t : not t [0 ].startswith ("#" ),
207+ describe_output ,
208+ )
209+ if column_name and column_name .strip () and column_type and column_type .strip ()
210+ }
211+ finally :
212+ if catalog and catalog != self ._default_catalog and current_catalog != catalog :
213+ if current_catalog is not None :
214+ self .set_current_catalog (current_catalog )
215+
def _drop_object(
    self,
    name: TableName | SchemaName,
    exists: bool = True,
    kind: str = "TABLE",
    cascade: bool = False,
    **drop_args: t.Any,
) -> None:
    """Drop a table-like object or schema, switching catalogs when required.

    When the object lives in a catalog other than the default one, the catalog
    qualifier is removed from the DROP statement and the adapter's current
    catalog is switched for the duration of the statement, then restored if a
    previous catalog was known.

    Args:
        name: The object (or schema) to drop.
        exists: Whether to emit ``IF EXISTS``.
        kind: The kind of object to drop, e.g. ``TABLE`` or ``SCHEMA``.
        cascade: Whether to cascade; only applied for kinds listed in
            ``SUPPORTED_DROP_CASCADE_OBJECT_KINDS``.
        **drop_args: Extra arguments forwarded to ``exp.Drop``.
    """
    kind_upper = kind.upper()
    if cascade and kind_upper in self.SUPPORTED_DROP_CASCADE_OBJECT_KINDS:
        drop_args["cascade"] = cascade

    target = exp.to_table(name).copy()

    # A schema name parsed as a table carries its catalog in the "db" slot.
    if kind_upper == "SCHEMA":
        catalog_arg, catalog = "db", target.db
    else:
        catalog_arg, catalog = "catalog", target.catalog

    switch_catalog = bool(catalog) and catalog != self._default_catalog
    previous_catalog = None

    if switch_catalog:
        # The catalog is supplied through the execution context instead of the SQL.
        target.set(catalog_arg, None)
        previous_catalog = self.get_current_catalog()
        if previous_catalog != catalog:
            self.set_current_catalog(catalog)

    try:
        self.execute(exp.Drop(this=target, kind=kind, exists=exists, **drop_args))
    finally:
        if switch_catalog and previous_catalog is not None and previous_catalog != catalog:
            self.set_current_catalog(previous_catalog)

    self._clear_data_object_cache(name)
160250
161251 def _create_schema (
162252 self ,
@@ -167,18 +257,39 @@ def _create_schema(
167257 kind : str ,
168258 ) -> None :
169259 schema = to_schema (schema_name )
170- if schema .catalog and schema .catalog != self ._default_catalog :
171- logger .info (
172- "Skipping creation of schema '%s' because Athena does not support creating schemas in non-default catalogs." ,
173- schema .sql (dialect = self .dialect ),
174- )
175- return
176260
177261 if location := self ._table_location (table_properties = None , table = exp .to_table (schema_name )):
178262 # don't add extra LocationProperty's if one already exists
179263 if not any (p for p in properties if isinstance (p , exp .LocationProperty )):
180264 properties .append (location )
181265
266+ if schema .catalog and schema .catalog != self ._default_catalog :
267+ target_schema = schema .copy ()
268+ catalog = target_schema .catalog
269+ target_schema .set ("catalog" , None )
270+
271+ current_catalog = self .get_current_catalog ()
272+ if current_catalog != catalog :
273+ self .set_current_catalog (catalog )
274+
275+ try :
276+ self .execute (
277+ exp .Create (
278+ this = target_schema ,
279+ kind = kind ,
280+ exists = ignore_if_exists ,
281+ properties = exp .Properties (expressions = properties ),
282+ )
283+ )
284+ except Exception as e :
285+ if not warn_on_error :
286+ raise
287+ logger .warning ("Failed to create %s '%s': %s" , kind .lower (), schema_name , e )
288+ finally :
289+ if current_catalog is not None and current_catalog != catalog :
290+ self .set_current_catalog (current_catalog )
291+ return
292+
182293 return super ()._create_schema (
183294 schema_name = schema_name ,
184295 ignore_if_exists = ignore_if_exists ,
@@ -187,6 +298,76 @@ def _create_schema(
187298 kind = kind ,
188299 )
189300
def _get_temp_table(
    self, table: TableName, table_only: bool = False, quoted: bool = True
) -> exp.Table:
    """Build the temp table reference to use for the given table name.

    Args:
        table: The table the temp table is derived from.
        table_only: When True, the schema and catalog qualifiers are removed.
        quoted: Whether the generated identifier is quoted.

    Returns:
        A copy of ``table`` whose name is ``temp_<name>_<random id>``.
    """
    from sqlmesh.utils import random_id

    temp_table = t.cast(exp.Table, exp.to_table(table).copy())

    # AWS S3 Tables (and Athena generally) prefer or require table names to start with a letter.
    # S3 Tables specifically fail with "The specified table name is not valid" if the
    # name starts with __temp_.
    temp_name = f"temp_{temp_table.name}_{random_id(short=True)}"
    temp_table.set("this", exp.to_identifier(temp_name, quoted=quoted))

    if table_only:
        for qualifier in ("db", "catalog"):
            temp_table.set(qualifier, None)

    return temp_table
322+
def _create_table(
    self,
    table_name_or_schema: t.Union[exp.Schema, TableName],
    expression: t.Optional[exp.Expr],
    exists: bool = True,
    replace: bool = False,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    table_kind: t.Optional[str] = None,
    track_rows_processed: bool = True,
    **kwargs: t.Any,
) -> None:
    """Create a table, switching the current catalog for non-CTAS creates.

    All arguments are forwarded unchanged to the parent implementation. When
    there is no ``expression`` (a plain CREATE TABLE rather than CTAS) and the
    table targets a catalog other than the default one, the current catalog is
    switched to the target for the duration of the call and restored
    afterwards if a previous catalog was known.
    """
    target: exp.Table
    if isinstance(table_name_or_schema, exp.Schema):
        target = table_name_or_schema.this
    elif isinstance(table_name_or_schema, str):
        target = exp.to_table(table_name_or_schema)
    else:
        target = table_name_or_schema

    catalog = target.catalog
    previous_catalog = self.get_current_catalog()

    # For non-CTAS CREATE TABLE in a non-default catalog, the catalog is stripped by
    # _build_create_table_exp, so it has to be supplied via the query execution context.
    needs_catalog_context = bool(
        not expression and catalog and catalog != self._default_catalog
    )

    if needs_catalog_context and previous_catalog != catalog:
        self.set_current_catalog(catalog)

    try:
        super()._create_table(
            table_name_or_schema=table_name_or_schema,
            expression=expression,
            exists=exists,
            replace=replace,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            table_kind=table_kind,
            track_rows_processed=track_rows_processed,
            **kwargs,
        )
    finally:
        if (
            needs_catalog_context
            and previous_catalog is not None
            and previous_catalog != catalog
        ):
            self.set_current_catalog(previous_catalog)
370+
190371 def _build_create_table_exp (
191372 self ,
192373 table_name_or_schema : t .Union [exp .Schema , TableName ],
@@ -240,8 +421,20 @@ def _build_create_table_exp(
240421 ]
241422 table_name_or_schema .args ["expressions" ] = filtered_expressions
242423
424+ create_table = table_name_or_schema .copy ()
425+
426+ # When creating a table without AS SELECT, Athena fails with "Unsupported ddl with 2 catalogs"
427+ # if a custom catalog like s3tablescatalog/supply is provided in the CREATE TABLE statement.
428+ # It requires the catalog to be provided via QueryExecutionContext instead.
429+ # The set_catalog decorator (which calls set_current_catalog) passes it to the QueryExecutionContext.
430+ # But we also need to strip it from the generated CREATE TABLE statement.
431+ # Note: We must strip the catalog from the table in the schema if table_name_or_schema is a schema.
432+ target_table = create_table .this if isinstance (create_table , exp .Schema ) else create_table
433+ if not expression and target_table .catalog and target_table .catalog != self ._default_catalog :
434+ target_table .set ("catalog" , None )
435+
243436 return exp .Create (
244- this = table_name_or_schema ,
437+ this = create_table ,
245438 kind = table_kind or "TABLE" ,
246439 replace = replace ,
247440 exists = exists ,
@@ -446,11 +639,29 @@ def _query_table_type_or_raise(self, table: exp.Table) -> TableType:
446639 """
447640 # Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here
448641 # This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks)
449- for row in self .fetchall (f"SHOW TBLPROPERTIES { table .sql (dialect = 'hive' , identify = True )} " ):
450- # This query returns a single column with values like 'EXTERNAL\tTRUE'
451- row_lower = row [0 ].lower ()
452- if "external" in row_lower and "true" in row_lower :
453- return "hive"
642+ target_table = table .copy ()
643+ if target_table .catalog and target_table .catalog != self ._default_catalog :
644+ catalog = target_table .catalog
645+ target_table .set ("catalog" , None )
646+
647+ current_catalog = self .get_current_catalog ()
648+ if current_catalog != catalog :
649+ self .set_current_catalog (catalog )
650+
651+ try :
652+ for row in self .fetchall (f"SHOW TBLPROPERTIES { target_table .sql (dialect = 'hive' , identify = True )} " ):
653+ row_lower = row [0 ].lower ()
654+ if "external" in row_lower and "true" in row_lower :
655+ return "hive"
656+ finally :
657+ if current_catalog is not None and current_catalog != catalog :
658+ self .set_current_catalog (current_catalog )
659+ else :
660+ for row in self .fetchall (f"SHOW TBLPROPERTIES { target_table .sql (dialect = 'hive' , identify = True )} " ):
661+ # This query returns a single column with values like 'EXTERNAL\tTRUE'
662+ row_lower = row [0 ].lower ()
663+ if "external" in row_lower and "true" in row_lower :
664+ return "hive"
454665 return "iceberg"
455666
456667 def _is_hive_partitioned_table (self , table : exp .Table ) -> bool :
@@ -700,5 +911,10 @@ def _boto3_client(self, name: str) -> t.Any:
700911 ** conn ._client_kwargs ,
701912 ) # type: ignore
702913
def set_current_catalog(self, catalog: str) -> None:
    """Make ``catalog`` the catalog used by the connection.

    The connection's ``catalog_name`` is always updated. The cursor's
    ``_catalog_name`` is updated only when the cursor already has that
    attribute.
    """
    self.connection.catalog_name = catalog

    cursor = self.cursor
    if hasattr(cursor, "_catalog_name"):
        cursor._catalog_name = catalog
918+
def get_current_catalog(self) -> t.Optional[str]:
    """Return the ``catalog_name`` currently set on the connection."""
    connection = self.connection
    return connection.catalog_name
0 commit comments