Skip to content

Commit 7dadeed

Browse files
author
vishnu
committed
delete S3 tables fixed
1 parent 9f253c3 commit 7dadeed

1 file changed

Lines changed: 233 additions & 17 deletions

File tree

sqlmesh/core/engine_adapter/athena.py

Lines changed: 233 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,26 @@ def _get_data_objects(
138138
for row in df.itertuples()
139139
]
140140

141+
def table_exists(self, table_name: TableName) -> bool:
142+
from sqlmesh.core.engine_adapter.base import _get_data_object_cache_key
143+
table = exp.to_table(table_name)
144+
data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
145+
146+
if data_object_cache_key in self._data_object_cache:
147+
logger.debug("Table existence cache hit: %s", data_object_cache_key)
148+
return self._data_object_cache[data_object_cache_key] is not None
149+
150+
try:
151+
# We don't use DESCRIBE because it fails with "Unsupported ddl with 2 catalogs"
152+
# for cross-catalog queries in Athena.
153+
# And since table_exists isn't run with the set_catalog decorator (which sets QueryExecutionContext),
154+
# we must fallback to a query that works with fully qualified names or
155+
# uses the information_schema/limit 0. A limit 0 select works with fully qualified names in Athena.
156+
self.execute(exp.select("1").from_(table).limit(0))
157+
return True
158+
except Exception:
159+
return False
160+
141161
def columns(
142162
self, table_name: TableName, include_pseudo_columns: bool = False
143163
) -> t.Dict[str, exp.DataType]:
@@ -152,11 +172,81 @@ def columns(
152172
.where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name))
153173
.order_by("ordinal_position")
154174
)
155-
result = self.fetchdf(query, quote_identifiers=True)
156-
return {
157-
str(r.column_name): exp.DataType.build(str(r.data_type))
158-
for r in result.itertuples(index=False)
159-
}
175+
176+
try:
177+
result = self.fetchdf(query, quote_identifiers=True)
178+
return {
179+
str(r.column_name): exp.DataType.build(str(r.data_type))
180+
for r in result.itertuples(index=False)
181+
}
182+
except Exception as e:
183+
# If information_schema query fails, we fallback to DESCRIBE.
184+
# But DESCRIBE with multiple catalogs fails in Athena, so we strip the catalog here
185+
# and rely on the set_current_catalog mechanism (applied at the EngineAdapter method level)
186+
# to set the catalog in the execution context.
187+
describe_table = table.copy()
188+
catalog = describe_table.catalog
189+
current_catalog = self.get_current_catalog()
190+
191+
if catalog and catalog != self._default_catalog:
192+
describe_table.set("catalog", None)
193+
if catalog != current_catalog:
194+
self.set_current_catalog(catalog)
195+
196+
try:
197+
self.execute(exp.Describe(this=describe_table, kind="TABLE"))
198+
199+
from sqlmesh.core.engine_adapter.base import _decoded_str
200+
import itertools
201+
describe_output = self.cursor.fetchall()
202+
return {
203+
# Note: MySQL returns the column type as bytes.
204+
column_name: exp.DataType.build(_decoded_str(column_type), dialect=self.dialect)
205+
for column_name, column_type, *_ in itertools.takewhile(
206+
lambda t: not t[0].startswith("#"),
207+
describe_output,
208+
)
209+
if column_name and column_name.strip() and column_type and column_type.strip()
210+
}
211+
finally:
212+
if catalog and catalog != self._default_catalog and current_catalog != catalog:
213+
if current_catalog is not None:
214+
self.set_current_catalog(current_catalog)
215+
216+
def _drop_object(
217+
self,
218+
name: TableName | SchemaName,
219+
exists: bool = True,
220+
kind: str = "TABLE",
221+
cascade: bool = False,
222+
**drop_args: t.Any,
223+
) -> None:
224+
if cascade and kind.upper() in self.SUPPORTED_DROP_CASCADE_OBJECT_KINDS:
225+
drop_args["cascade"] = cascade
226+
227+
target_table = exp.to_table(name).copy()
228+
is_schema = kind.upper() == "SCHEMA"
229+
catalog = target_table.db if is_schema else target_table.catalog
230+
231+
if catalog and catalog != self._default_catalog:
232+
if is_schema:
233+
target_table.set("db", None)
234+
else:
235+
target_table.set("catalog", None)
236+
237+
current_catalog = self.get_current_catalog()
238+
if current_catalog != catalog:
239+
self.set_current_catalog(catalog)
240+
241+
try:
242+
self.execute(exp.Drop(this=target_table, kind=kind, exists=exists, **drop_args))
243+
finally:
244+
if current_catalog is not None and current_catalog != catalog:
245+
self.set_current_catalog(current_catalog)
246+
else:
247+
self.execute(exp.Drop(this=target_table, kind=kind, exists=exists, **drop_args))
248+
249+
self._clear_data_object_cache(name)
160250

161251
def _create_schema(
162252
self,
@@ -167,18 +257,39 @@ def _create_schema(
167257
kind: str,
168258
) -> None:
169259
schema = to_schema(schema_name)
170-
if schema.catalog and schema.catalog != self._default_catalog:
171-
logger.info(
172-
"Skipping creation of schema '%s' because Athena does not support creating schemas in non-default catalogs.",
173-
schema.sql(dialect=self.dialect),
174-
)
175-
return
176260

177261
if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)):
178262
# don't add extra LocationProperty's if one already exists
179263
if not any(p for p in properties if isinstance(p, exp.LocationProperty)):
180264
properties.append(location)
181265

266+
if schema.catalog and schema.catalog != self._default_catalog:
267+
target_schema = schema.copy()
268+
catalog = target_schema.catalog
269+
target_schema.set("catalog", None)
270+
271+
current_catalog = self.get_current_catalog()
272+
if current_catalog != catalog:
273+
self.set_current_catalog(catalog)
274+
275+
try:
276+
self.execute(
277+
exp.Create(
278+
this=target_schema,
279+
kind=kind,
280+
exists=ignore_if_exists,
281+
properties=exp.Properties(expressions=properties),
282+
)
283+
)
284+
except Exception as e:
285+
if not warn_on_error:
286+
raise
287+
logger.warning("Failed to create %s '%s': %s", kind.lower(), schema_name, e)
288+
finally:
289+
if current_catalog is not None and current_catalog != catalog:
290+
self.set_current_catalog(current_catalog)
291+
return
292+
182293
return super()._create_schema(
183294
schema_name=schema_name,
184295
ignore_if_exists=ignore_if_exists,
@@ -187,6 +298,76 @@ def _create_schema(
187298
kind=kind,
188299
)
189300

301+
def _get_temp_table(
302+
self, table: TableName, table_only: bool = False, quoted: bool = True
303+
) -> exp.Table:
304+
"""
305+
Returns the name of the temp table that should be used for the given table name.
306+
"""
307+
from sqlmesh.utils import random_id
308+
309+
table = t.cast(exp.Table, exp.to_table(table).copy())
310+
311+
# AWS S3 Tables (and Athena generally) prefer or require table names to start with a letter.
312+
# S3 Tables specifically fail with: "The specified table name is not valid" if it starts with __temp_
313+
table.set(
314+
"this", exp.to_identifier(f"temp_{table.name}_{random_id(short=True)}", quoted=quoted)
315+
)
316+
317+
if table_only:
318+
table.set("db", None)
319+
table.set("catalog", None)
320+
321+
return table
322+
323+
def _create_table(
324+
self,
325+
table_name_or_schema: t.Union[exp.Schema, TableName],
326+
expression: t.Optional[exp.Expr],
327+
exists: bool = True,
328+
replace: bool = False,
329+
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
330+
table_description: t.Optional[str] = None,
331+
column_descriptions: t.Optional[t.Dict[str, str]] = None,
332+
table_kind: t.Optional[str] = None,
333+
track_rows_processed: bool = True,
334+
**kwargs: t.Any,
335+
) -> None:
336+
table: exp.Table
337+
if isinstance(table_name_or_schema, str):
338+
table = exp.to_table(table_name_or_schema)
339+
elif isinstance(table_name_or_schema, exp.Schema):
340+
table = table_name_or_schema.this
341+
else:
342+
table = table_name_or_schema
343+
344+
catalog = table.catalog
345+
current_catalog = self.get_current_catalog()
346+
347+
# For non-CTAS CREATE TABLE in a non-default catalog, the catalog is stripped by _build_create_table_exp.
348+
# We need to set the query execution context here.
349+
if not expression and catalog and catalog != self._default_catalog:
350+
if current_catalog != catalog:
351+
self.set_current_catalog(catalog)
352+
353+
try:
354+
super()._create_table(
355+
table_name_or_schema=table_name_or_schema,
356+
expression=expression,
357+
exists=exists,
358+
replace=replace,
359+
target_columns_to_types=target_columns_to_types,
360+
table_description=table_description,
361+
column_descriptions=column_descriptions,
362+
table_kind=table_kind,
363+
track_rows_processed=track_rows_processed,
364+
**kwargs,
365+
)
366+
finally:
367+
if not expression and catalog and catalog != self._default_catalog:
368+
if current_catalog is not None and current_catalog != catalog:
369+
self.set_current_catalog(current_catalog)
370+
190371
def _build_create_table_exp(
191372
self,
192373
table_name_or_schema: t.Union[exp.Schema, TableName],
@@ -240,8 +421,20 @@ def _build_create_table_exp(
240421
]
241422
table_name_or_schema.args["expressions"] = filtered_expressions
242423

424+
create_table = table_name_or_schema.copy()
425+
426+
# When creating a table without AS SELECT, Athena fails with "Unsupported ddl with 2 catalogs"
427+
# if a custom catalog like s3tablescatalog/supply is provided in the CREATE TABLE statement.
428+
# It requires the catalog to be provided via QueryExecutionContext instead.
429+
# The set_catalog decorator (which calls set_current_catalog) passes it to the QueryExecutionContext.
430+
# But we also need to strip it from the generated CREATE TABLE statement.
431+
# Note: We must strip the catalog from the table in the schema if table_name_or_schema is a schema.
432+
target_table = create_table.this if isinstance(create_table, exp.Schema) else create_table
433+
if not expression and target_table.catalog and target_table.catalog != self._default_catalog:
434+
target_table.set("catalog", None)
435+
243436
return exp.Create(
244-
this=table_name_or_schema,
437+
this=create_table,
245438
kind=table_kind or "TABLE",
246439
replace=replace,
247440
exists=exists,
@@ -446,11 +639,29 @@ def _query_table_type_or_raise(self, table: exp.Table) -> TableType:
446639
"""
447640
# Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here
448641
# This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks)
449-
for row in self.fetchall(f"SHOW TBLPROPERTIES {table.sql(dialect='hive', identify=True)}"):
450-
# This query returns a single column with values like 'EXTERNAL\tTRUE'
451-
row_lower = row[0].lower()
452-
if "external" in row_lower and "true" in row_lower:
453-
return "hive"
642+
target_table = table.copy()
643+
if target_table.catalog and target_table.catalog != self._default_catalog:
644+
catalog = target_table.catalog
645+
target_table.set("catalog", None)
646+
647+
current_catalog = self.get_current_catalog()
648+
if current_catalog != catalog:
649+
self.set_current_catalog(catalog)
650+
651+
try:
652+
for row in self.fetchall(f"SHOW TBLPROPERTIES {target_table.sql(dialect='hive', identify=True)}"):
653+
row_lower = row[0].lower()
654+
if "external" in row_lower and "true" in row_lower:
655+
return "hive"
656+
finally:
657+
if current_catalog is not None and current_catalog != catalog:
658+
self.set_current_catalog(current_catalog)
659+
else:
660+
for row in self.fetchall(f"SHOW TBLPROPERTIES {target_table.sql(dialect='hive', identify=True)}"):
661+
# This query returns a single column with values like 'EXTERNAL\tTRUE'
662+
row_lower = row[0].lower()
663+
if "external" in row_lower and "true" in row_lower:
664+
return "hive"
454665
return "iceberg"
455666

456667
def _is_hive_partitioned_table(self, table: exp.Table) -> bool:
@@ -700,5 +911,10 @@ def _boto3_client(self, name: str) -> t.Any:
700911
**conn._client_kwargs,
701912
) # type: ignore
702913

914+
def set_current_catalog(self, catalog: str) -> None:
915+
self.connection.catalog_name = catalog
916+
if hasattr(self.cursor, "_catalog_name"):
917+
self.cursor._catalog_name = catalog
918+
703919
def get_current_catalog(self) -> t.Optional[str]:
704920
return self.connection.catalog_name

0 commit comments

Comments
 (0)