|
6 | 6 |
|
7 | 7 | import pandas as pd |
8 | 8 | from sqlglot import exp |
9 | | - |
| 9 | +from sqlmesh.core.dialect import to_schema |
10 | 10 | from sqlmesh.core.engine_adapter.shared import ( |
11 | 11 | CatalogSupport, |
12 | 12 | DataObject, |
| 13 | + DataObjectType, |
13 | 14 | InsertOverwriteStrategy, |
14 | | - set_catalog, |
15 | 15 | SourceQuery, |
16 | 16 | ) |
17 | 17 | from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter |
|
27 | 27 | logger = logging.getLogger(__name__) |
28 | 28 |
|
29 | 29 |
|
30 | | -@set_catalog( |
31 | | - { |
32 | | - "_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG, |
33 | | - } |
34 | | -) |
35 | 30 | class DatabricksEngineAdapter(SparkEngineAdapter): |
36 | 31 | DIALECT = "databricks" |
37 | 32 | INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE |
@@ -251,7 +246,43 @@ def _set_spark_session_current_catalog(spark: PySparkSession) -> None: |
251 | 246 | def _get_data_objects( |
252 | 247 | self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None |
253 | 248 | ) -> t.List[DataObject]: |
254 | | - return super()._get_data_objects(schema_name, object_names=object_names) |
| 249 | + """ |
| 250 | + Returns all the data objects that exist in the given schema and catalog. |
| 251 | + """ |
| 252 | + schema = to_schema(schema_name) |
| 253 | + catalog_name = schema.catalog or self.get_current_catalog() |
| 254 | + query = ( |
| 255 | + exp.select( |
| 256 | + exp.column("table_name").as_("name"), |
| 257 | + exp.column("table_schema").as_("schema"), |
| 258 | + exp.column("table_catalog").as_("catalog"), |
| 259 | + exp.case(exp.column("table_type")) |
| 260 | + .when(exp.Literal.string("VIEW"), exp.Literal.string("view")) |
| 261 | + .when(exp.Literal.string("MATERIALIZED_VIEW"), exp.Literal.string("view")) |
| 262 | + .else_(exp.Literal.string("table")) |
| 263 | + .as_("type"), |
| 264 | + ) |
| 265 | + .from_( |
| 266 | + # always query `system` information_schema |
| 267 | + exp.table_("tables", "information_schema", "system") |
| 268 | + ) |
| 269 | + .where(exp.column("table_catalog").eq(catalog_name)) |
| 270 | + .where(exp.column("table_schema").eq(schema.db)) |
| 271 | + ) |
| 272 | + |
| 273 | + if object_names: |
| 274 | + query = query.where(exp.column("table_name").isin(*object_names)) |
| 275 | + |
| 276 | + df = self.fetchdf(query) |
| 277 | + return [ |
| 278 | + DataObject( |
| 279 | + catalog=row.catalog, # type: ignore |
| 280 | + schema=row.schema, # type: ignore |
| 281 | + name=row.name, # type: ignore |
| 282 | + type=DataObjectType.from_str(row.type), # type: ignore |
| 283 | + ) |
| 284 | + for row in df.itertuples() |
| 285 | + ] |
255 | 286 |
|
256 | 287 | def clone_table( |
257 | 288 | self, |
|
0 commit comments