-
Notifications
You must be signed in to change notification settings - Fork 380
Expand file tree
/
Copy path trino.py
More file actions
439 lines (396 loc) · 17.6 KB
/
trino.py
File metadata and controls
439 lines (396 loc) · 17.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
from __future__ import annotations
import contextlib
import re
import typing as t
from functools import lru_cache
from sqlglot import exp
from sqlglot.helper import seq_get
from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_result
from sqlmesh.core.dialect import schema_, to_schema
from sqlmesh.core.engine_adapter.mixins import (
GetCurrentCatalogFromFunctionMixin,
HiveMetastoreTablePropertiesMixin,
PandasNativeFetchDFSupportMixin,
RowDiffMixin,
)
from sqlmesh.core.engine_adapter.shared import (
CatalogSupport,
CommentCreationTable,
CommentCreationView,
DataObject,
DataObjectType,
InsertOverwriteStrategy,
SourceQuery,
set_catalog,
)
from sqlmesh.utils import get_source_columns_to_types
from sqlmesh.utils.errors import SQLMeshError
from sqlmesh.utils.date import TimeLike
if t.TYPE_CHECKING:
    from sqlmesh.core._typing import SchemaName, SessionProperties, TableName
    from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF

# Connector types for which `TrinoEngineAdapter.replace_query` enables replace-table support.
# Entries are matched by substring against the catalog's connector type, so a custom type such
# as `acme_iceberg` also matches. Immutable so no caller can accidentally widen the set at runtime.
CATALOG_TYPES_SUPPORTING_REPLACE_TABLE = frozenset({"iceberg", "delta_lake"})
@set_catalog()
class TrinoEngineAdapter(
    PandasNativeFetchDFSupportMixin,
    HiveMetastoreTablePropertiesMixin,
    GetCurrentCatalogFromFunctionMixin,
    RowDiffMixin,
):
    """Engine adapter for Trino.

    Several operations branch on the connector type ("catalog type") backing the target
    catalog, e.g. "hive", "iceberg" or "delta_lake", as resolved by `get_catalog_type`.
    """

    DIALECT = "trino"
    INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INTO_IS_OVERWRITE
    # Trino does technically support transactions but it doesn't work correctly with partition overwrite so we
    # disable transactions. If we need to get them enabled again then we would need to disable auto commit on the
    # connector and then figure out how to get insert/overwrite to work correctly without it.
    SUPPORTS_TRANSACTIONS = False
    CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog")
    COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS
    COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY
    SUPPORTS_REPLACE_TABLE = False
    SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"]
    DEFAULT_CATALOG_TYPE = "hive"
    QUOTE_IDENTIFIERS_IN_VIEWS = False
    SUPPORTS_QUERY_EXECUTION_TRACKING = True
    SCHEMA_DIFFER_KWARGS = {
        "parameterized_type_defaults": {
            # default decimal precision varies across backends
            exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)],
            exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)],
            exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(3,)],
        },
    }
    # Some catalogs support microsecond (precision 6) timestamps, but it either has to be enabled
    # explicitly (Hive) or isn't available at all (Delta / TIMESTAMP WITH TIME ZONE). Even with a
    # TIMESTAMP(6) column, the date formatting functions still only support millisecond precision.
    MAX_TIMESTAMP_PRECISION = 3

    @property
    def schema_location_mapping(self) -> t.Optional[dict[re.Pattern, str]]:
        """Optional `{regex: location_template}` mapping read from the connection's extra config.

        Used by `_schema_location` to choose a LOCATION for newly created schemas.
        """
        return self._extra_config.get("schema_location_mapping")

    @property
    def catalog_support(self) -> CatalogSupport:
        return CatalogSupport.FULL_SUPPORT

    def set_current_catalog(self, catalog: str) -> None:
        """Sets the catalog name of the current connection."""
        # `USE` takes a catalog-qualified schema; `information_schema` is used as the schema
        # part (presumably because it is present in every catalog).
        self.execute(exp.Use(this=schema_(db="information_schema", catalog=catalog)))

    @lru_cache()
    def get_catalog_type(self, catalog: t.Optional[str]) -> str:
        """Return the connector name backing `catalog` (e.g. "hive", "iceberg").

        Results are memoized per (adapter, catalog) by `lru_cache`.

        Args:
            catalog: The catalog *name* to look up.

        Returns:
            The `connector_name` from `system.metadata.catalogs`, or `DEFAULT_CATALOG_TYPE`
            when `catalog` is empty or has no matching row.
        """
        if not catalog:
            return self.DEFAULT_CATALOG_TYPE

        # Build the lookup as an expression so the catalog name is rendered as an escaped string
        # literal instead of being spliced into raw SQL (a quote in the name would break it).
        query = (
            exp.select("connector_name")
            .from_(exp.to_table("system.metadata.catalogs"))
            .where(exp.column("catalog_name").eq(exp.Literal.string(catalog)))
        )
        row = self.fetchone(query) or ()
        return seq_get(row, 0) or self.DEFAULT_CATALOG_TYPE

    @contextlib.contextmanager
    def session(self, properties: SessionProperties) -> t.Iterator[None]:
        """Run the enclosed block under `SET SESSION AUTHORIZATION` when requested.

        Only the `authorization` property is handled. If it is absent or falsy, the block runs
        unchanged. Otherwise the authorization is set before the block and reset afterwards,
        even if the block raises.

        Raises:
            SQLMeshError: If `authorization` is an expression that is not a string literal.
        """
        authorization = properties.get("authorization")
        if not authorization:
            yield
            return

        if not isinstance(authorization, exp.Expression):
            authorization = exp.Literal.string(authorization)

        if not authorization.is_string:
            raise SQLMeshError(
                "Invalid value for `session_properties.authorization`. Must be a string literal."
            )

        authorization_sql = authorization.sql(dialect=self.dialect)

        self.execute(f"SET SESSION AUTHORIZATION {authorization_sql}")
        try:
            yield
        finally:
            self.execute("RESET SESSION AUTHORIZATION")

    def replace_query(
        self,
        table_name: TableName,
        query_or_df: QueryOrDF,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        source_columns: t.Optional[t.List[str]] = None,
        supports_replace_table_override: t.Optional[bool] = None,
        **kwargs: t.Any,
    ) -> None:
        """Replace the contents of `table_name` with `query_or_df`.

        Replace-table support is turned on only when the table's connector type contains one of
        `CATALOG_TYPES_SUPPORTING_REPLACE_TABLE`. Otherwise `None` is passed, deferring to the
        parent's default. The incoming `supports_replace_table_override` value is ignored and
        recomputed from the catalog type.
        """
        # `get_catalog_type_from_table` already resolves the table's catalog to its connector
        # type (it is used that way elsewhere in this class). It must not be fed back into
        # `get_catalog_type`, which would look the *type* up as a catalog *name*.
        catalog_type = self.get_catalog_type_from_table(table_name)
        # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name
        # Ex: `acme_iceberg` would be identified as an iceberg catalog and therefore supports replace table
        supports_replace_table_override = (
            True
            if any(
                replace_table_catalog_type in catalog_type
                for replace_table_catalog_type in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE
            )
            else None
        )
        super().replace_query(
            table_name=table_name,
            query_or_df=query_or_df,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            source_columns=source_columns,
            supports_replace_table_override=supports_replace_table_override,
            **kwargs,
        )

    def _insert_overwrite_by_condition(
        self,
        table_name: TableName,
        source_queries: t.List[SourceQuery],
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        where: t.Optional[exp.Condition] = None,
        insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None,
        **kwargs: t.Any,
    ) -> None:
        """Overwrite the rows of `table_name` selected by `where` with `source_queries`.

        For Hive catalogs with a `where` clause, the session is temporarily switched to
        partition-OVERWRITE mode around the class's INTO_IS_OVERWRITE insert. All other cases
        use the DELETE_INSERT strategy.
        """
        catalog = exp.to_table(table_name).catalog or self.get_current_catalog()
        if where and self.get_catalog_type(catalog) == "hive":
            # These session properties are only valid for the Trino Hive connector
            # Attempting to set them on an Iceberg catalog will throw an error:
            # "Session property 'catalog.insert_existing_partitions_behavior' does not exist"
            self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='OVERWRITE'")
            try:
                super()._insert_overwrite_by_condition(
                    table_name, source_queries, target_columns_to_types, where
                )
            finally:
                # Always restore APPEND. If a failed insert left the session in OVERWRITE mode,
                # later INSERTs into this catalog on the same session could overwrite partitions.
                self.execute(
                    f"SET SESSION {catalog}.insert_existing_partitions_behavior='APPEND'"
                )
        else:
            super()._insert_overwrite_by_condition(
                table_name,
                source_queries,
                target_columns_to_types,
                where,
                insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT,
            )

    def _truncate_table(self, table_name: TableName) -> None:
        """Remove all rows from `table_name` using `DELETE FROM`."""
        table = exp.to_table(table_name)
        # Some trino connectors don't support truncate so we use delete.
        self.execute(f"DELETE FROM {table.sql(dialect=self.dialect, identify=True)}")

    def _get_data_objects(
        self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None
    ) -> t.List[DataObject]:
        """
        Returns all the data objects that exist in the given schema and optionally catalog.

        Tables come from `<catalog>.information_schema.tables`. A LEFT JOIN to
        `system.metadata.materialized_views` labels materialized views as "materialized_view";
        "BASE TABLE" becomes "table"; any other `table_type` is passed through unchanged.
        When `object_names` is non-empty, only those table names are returned.
        """
        schema_name = to_schema(schema_name)
        schema = schema_name.db
        catalog = schema_name.catalog or self.get_current_catalog()
        query = (
            exp.select(
                exp.column("table_catalog", table="t").as_("catalog"),
                exp.column("table_schema", table="t").as_("schema"),
                exp.column("table_name", table="t").as_("name"),
                exp.case()
                .when(
                    exp.column("name", table="mv").is_(exp.null()).not_(),
                    exp.Literal.string("materialized_view"),
                )
                .when(
                    exp.column("table_type", table="t").eq("BASE TABLE"),
                    exp.Literal.string("table"),
                )
                .else_(exp.column("table_type", table="t"))
                .as_("type"),
            )
            .from_(exp.to_table(f"{catalog}.information_schema.tables", alias="t"))
            .join(
                exp.to_table("system.metadata.materialized_views", alias="mv"),
                on=exp.and_(
                    exp.column("catalog_name", table="mv").eq(
                        exp.column("table_catalog", table="t")
                    ),
                    exp.column("schema_name", table="mv").eq(exp.column("table_schema", table="t")),
                    exp.column("name", table="mv").eq(exp.column("table_name", table="t")),
                ),
                join_type="left",
            )
            .where(
                exp.and_(
                    exp.column("table_schema", table="t").eq(schema),
                    # Rows without a materialized-view match have NULL mv columns after the LEFT
                    # JOIN; keep those as well as matches in the requested catalog/schema.
                    exp.or_(
                        exp.column("catalog_name", table="mv").is_(exp.null()),
                        exp.column("catalog_name", table="mv").eq(catalog),
                    ),
                    exp.or_(
                        exp.column("schema_name", table="mv").is_(exp.null()),
                        exp.column("schema_name", table="mv").eq(schema),
                    ),
                )
            )
        )
        if object_names:
            query = query.where(exp.column("table_name", table="t").isin(*object_names))
        df = self.fetchdf(query)
        return [
            DataObject(
                catalog=row.catalog,  # type: ignore
                schema=row.schema,  # type: ignore
                name=row.name,  # type: ignore
                type=DataObjectType.from_str(row.type),  # type: ignore
            )
            for row in df.itertuples()
        ]

    def _df_to_source_queries(
        self,
        df: DF,
        target_columns_to_types: t.Dict[str, exp.DataType],
        batch_size: int,
        target_table: TableName,
        source_columns: t.Optional[t.List[str]] = None,
    ) -> t.List[SourceQuery]:
        """Convert a pandas DataFrame into source queries, normalizing tz-aware timestamps.

        Timezone-aware datetime columns are rendered as space-separated ISO strings, and nulls
        become `None`. The caller's DataFrame is never modified; a copy is taken when any
        conversion is needed.
        """
        import pandas as pd
        from pandas.api.types import is_datetime64_any_dtype  # type: ignore

        assert isinstance(df, pd.DataFrame)
        source_columns_to_types = get_source_columns_to_types(
            target_columns_to_types, source_columns
        )

        # Trino does not accept timestamps in ISOFORMAT that include the "T". `execution_time` is stored in
        # Pandas with that format, so we convert the column to a string with the proper format and CAST to
        # timestamp in Trino.
        tz_aware_columns = [
            column
            for column in source_columns_to_types
            if is_datetime64_any_dtype(df.dtypes[column])
            and getattr(df.dtypes[column], "tz", None) is not None
        ]
        if tz_aware_columns:
            # Work on a copy so the caller's DataFrame is left untouched.
            df = df.copy()
            for column in tz_aware_columns:
                # `NaT.isoformat()` yields the literal string "NaT", which Trino cannot cast to
                # a timestamp, so nulls are mapped to None instead.
                df[column] = pd.to_datetime(df[column]).map(
                    lambda x: None if pd.isna(x) else x.isoformat(" ")
                )

        return super()._df_to_source_queries(
            df, target_columns_to_types, batch_size, target_table, source_columns=source_columns
        )

    def _build_schema_exp(
        self,
        table: exp.Table,
        target_columns_to_types: t.Dict[str, exp.DataType],
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        expressions: t.Optional[t.List[exp.PrimaryKey]] = None,
        is_view: bool = False,
    ) -> exp.Schema:
        """Build the schema expression, mapping timestamp types for Delta Lake catalogs."""
        if "delta_lake" in self.get_catalog_type_from_table(table):
            target_columns_to_types = self._to_delta_ts(target_columns_to_types)

        return super()._build_schema_exp(
            table, target_columns_to_types, column_descriptions, expressions, is_view
        )

    def _scd_type_2(
        self,
        target_table: TableName,
        source_table: QueryOrDF,
        unique_key: t.Sequence[exp.Expression],
        valid_from_col: exp.Column,
        valid_to_col: exp.Column,
        execution_time: t.Union[TimeLike, exp.Column],
        invalidate_hard_deletes: bool = True,
        updated_at_col: t.Optional[exp.Column] = None,
        check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None,
        updated_at_as_valid_from: bool = False,
        execution_time_as_valid_from: bool = False,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        truncate: bool = False,
        source_columns: t.Optional[t.List[str]] = None,
        **kwargs: t.Any,
    ) -> None:
        """SCD type 2 load; maps timestamp column types first when the target is Delta Lake."""
        if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table(
            target_table
        ):
            target_columns_to_types = self._to_delta_ts(target_columns_to_types)

        return super()._scd_type_2(
            target_table,
            source_table,
            unique_key,
            valid_from_col,
            valid_to_col,
            execution_time,
            invalidate_hard_deletes,
            updated_at_col,
            check_columns,
            updated_at_as_valid_from,
            execution_time_as_valid_from,
            target_columns_to_types,
            table_description,
            column_descriptions,
            truncate,
            source_columns,
            **kwargs,
        )

    # delta_lake only supports two timestamp data types. This method converts other
    # timestamp types to those two for use in DDL statements. Trino/delta automatically
    # converts the data values to the correct type on write, so we only need to handle
    # the column types in DDL.
    #   - `timestamp(6)` for non-timezone-aware
    #   - `timestamp(3) with time zone` for timezone-aware
    # https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping
    def _to_delta_ts(
        self, columns_to_types: t.Dict[str, exp.DataType]
    ) -> t.Dict[str, exp.DataType]:
        """Return a new mapping with TIMESTAMP/TIMESTAMPTZ columns replaced by Delta's types.

        Non-timestamp types are passed through unchanged. The input mapping is not modified.
        """
        ts6 = exp.DataType.build("timestamp(6)")
        ts3_tz = exp.DataType.build("timestamp(3) with time zone")

        def _convert(data_type: exp.DataType) -> exp.DataType:
            # Copy per column so no single sqlglot node is shared between several columns.
            if data_type.is_type(exp.DataType.Type.TIMESTAMP):
                return ts6.copy()
            if data_type.is_type(exp.DataType.Type.TIMESTAMPTZ):
                return ts3_tz.copy()
            return data_type

        return {column: _convert(data_type) for column, data_type in columns_to_types.items()}

    @retry(wait=wait_fixed(1), stop=stop_after_attempt(10), retry=retry_if_result(lambda v: not v))
    def _block_until_table_exists(self, table_name: TableName) -> bool:
        """Poll `table_exists` once a second, up to 10 attempts, until it is truthy.

        If the table never appears, tenacity raises `RetryError` when the attempts run out.
        """
        return self.table_exists(table_name)

    def _create_schema(
        self,
        schema_name: SchemaName,
        ignore_if_exists: bool,
        warn_on_error: bool,
        properties: t.List[exp.Expression],
        kind: str,
    ) -> None:
        """Create a schema, adding a LOCATION property when `schema_location_mapping` matches."""
        if mapped_location := self._schema_location(schema_name):
            # Build a new list rather than appending, so the caller's `properties` is not mutated.
            properties = [
                *properties,
                exp.LocationProperty(this=exp.Literal.string(mapped_location)),
            ]

        return super()._create_schema(
            schema_name=schema_name,
            ignore_if_exists=ignore_if_exists,
            warn_on_error=warn_on_error,
            properties=properties,
            kind=kind,
        )

    def _create_table(
        self,
        table_name_or_schema: t.Union[exp.Schema, TableName],
        expression: t.Optional[exp.Expression],
        exists: bool = True,
        replace: bool = False,
        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
        table_description: t.Optional[str] = None,
        column_descriptions: t.Optional[t.Dict[str, str]] = None,
        table_kind: t.Optional[str] = None,
        track_rows_processed: bool = True,
        **kwargs: t.Any,
    ) -> None:
        """Create the table; on Hive catalogs, block until the new table is visible."""
        super()._create_table(
            table_name_or_schema=table_name_or_schema,
            expression=expression,
            exists=exists,
            replace=replace,
            target_columns_to_types=target_columns_to_types,
            table_description=table_description,
            column_descriptions=column_descriptions,
            table_kind=table_kind,
            track_rows_processed=track_rows_processed,
            **kwargs,
        )

        # extract the table name
        if isinstance(table_name_or_schema, exp.Schema):
            table_name = table_name_or_schema.this
            assert isinstance(table_name, exp.Table)
        else:
            table_name = table_name_or_schema

        if "hive" in self.get_catalog_type_from_table(table_name):
            # the Trino Hive connector can take a few seconds for metadata changes to propagate to all internal threads
            # (even if metadata TTL is set to 0s)
            # Blocking until the table shows up means that subsequent code expecting it to exist immediately will not fail
            self._block_until_table_exists(table_name)

    def _schema_location(self, schema_name: SchemaName) -> t.Optional[str]:
        """Resolve a LOCATION for `schema_name` from `schema_location_mapping`.

        The match key is "catalog.schema", or just "schema" when no catalog is given. The first
        mapping key that `re.match`es it (anchored at the start only) wins. In its template,
        `@{schema_name}` and `@{catalog_name}` are substituted; `@{catalog_name}` becomes an
        empty string when there is no catalog. Returns None when nothing matches.
        """
        if mapping := self.schema_location_mapping:
            schema = to_schema(schema_name)
            match_key = schema.db

            # only consider the catalog if it is present
            if schema.catalog:
                match_key = f"{schema.catalog}.{match_key}"

            for k, v in mapping.items():
                if re.match(k, match_key):
                    return v.replace("@{schema_name}", schema.db).replace(
                        "@{catalog_name}", schema.catalog
                    )

        return None