diff --git a/sqlmesh/core/engine_adapter/risingwave.py b/sqlmesh/core/engine_adapter/risingwave.py index f32ce2f457..fdcee90f0f 100644 --- a/sqlmesh/core/engine_adapter/risingwave.py +++ b/sqlmesh/core/engine_adapter/risingwave.py @@ -14,6 +14,7 @@ CommentCreationTable, ) +from sqlmesh.utils.errors import SQLMeshError if t.TYPE_CHECKING: from sqlmesh.core._typing import TableName @@ -32,5 +33,37 @@ class RisingwaveEngineAdapter(PostgresEngineAdapter): SUPPORTS_TRANSACTIONS = False MAX_IDENTIFIER_LENGTH = None + def columns( + self, table_name: TableName, include_pseudo_columns: bool = False + ) -> t.Dict[str, exp.DataType]: + """Fetches column names and types for the target_table""" + table = exp.to_table(table_name) + + sql = ( + exp.select("rw_columns.name AS column_name", "rw_columns.data_type AS data_type") + .from_("rw_catalog.rw_columns") + .join("rw_catalog.rw_relations", on="rw_relations.id=rw_columns.relation_id") + .join("rw_catalog.rw_schemas", on="rw_schemas.id=rw_relations.schema_id") + .where( + exp.and_( + exp.column("name", table="rw_relations").eq(table.alias_or_name), + exp.column("name", table="rw_columns").neq("_row_id"), + exp.column("name", table="rw_columns").neq("_rw_timestamp"), + ) + ) + ) + + if table.db: + sql = sql.where(exp.column("name", table="rw_schemas").eq(table.db)) + + self.execute(sql) + resp = self.cursor.fetchall() + if not resp: + raise SQLMeshError(f"Could not get columns for table {table_name}. Table not found.") + return { + column_name: exp.DataType.build(data_type, dialect=self.dialect, udt=True) + for column_name, data_type in resp + } + def _truncate_table(self, table_name: TableName) -> None: return self.execute(exp.Delete(this=exp.to_table(table_name))) diff --git a/tests/core/engine_adapter/integration/test_integration_risingwave.py b/tests/core/engine_adapter/integration/test_integration_risingwave.py new file mode 100644 index 0000000000..76b3d20a7c --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_risingwave.py @@ -0,0 +1,82 @@ +import typing as t +import pytest +from sqlglot import exp +from pytest import FixtureRequest +from sqlmesh.core.engine_adapter import RisingwaveEngineAdapter +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES_BY_NAME, + IntegrationTestEngine, +) + + +@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["risingwave"]))) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture +def engine_adapter(ctx: TestContext) -> RisingwaveEngineAdapter: + assert isinstance(ctx.engine_adapter, RisingwaveEngineAdapter) + return ctx.engine_adapter + + +@pytest.fixture +def risingwave_columns_with_datatypes(ctx: TestContext) -> t.Dict[str, exp.DataType]: + base_types = { + "smallint_col": exp.DataType.build(exp.DataType.Type.SMALLINT, nested=False), + "int_col": exp.DataType.build(exp.DataType.Type.INT, nested=False), + "bigint_col": exp.DataType.build(exp.DataType.Type.BIGINT, nested=False), + "ts_col": exp.DataType.build(exp.DataType.Type.TIMESTAMP, nested=False), + "tstz_col": exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ, nested=False), + "vchar_col": exp.DataType.build(exp.DataType.Type.VARCHAR, nested=False), + } + # generate all arrays of base types + arr_types = { + f"{type_name}_arr_col": exp.DataType.build( + exp.DataType.Type.ARRAY, + expressions=[base_type], + nested=True, + ) + for type_name, base_type in base_types.items() + } + # generate struct with all base types as nested columns + struct_types = { + "struct_col": exp.DataType.build( + exp.DataType.Type.STRUCT, + expressions=[ + exp.ColumnDef( + this=exp.Identifier(this=f"nested_{type_name}_col", quoted=False), + kind=base_type, + ) + for type_name, base_type in base_types.items() + ], + nested=True, + ) + } + return {**base_types, **arr_types, **struct_types} + + +def test_engine_adapter(ctx: TestContext): + assert isinstance(ctx.engine_adapter, RisingwaveEngineAdapter) + assert ctx.engine_adapter.fetchone("select 1") == (1,) + + +def test_engine_adapter_columns( + ctx: TestContext, risingwave_columns_with_datatypes: t.Dict[str, exp.DataType] +): + table = ctx.table("TEST_COLUMNS") + query = exp.select( + *[ + exp.cast(exp.null(), dtype).as_(name) + for name, dtype in risingwave_columns_with_datatypes.items() + ] + ) + ctx.engine_adapter.ctas(table, query) + + column_result = ctx.engine_adapter.columns(table) + assert column_result == risingwave_columns_with_datatypes diff --git a/tests/core/engine_adapter/test_risingwave.py b/tests/core/engine_adapter/test_risingwave.py index 6718690283..ed3cd77a3f 100644 --- a/tests/core/engine_adapter/test_risingwave.py +++ b/tests/core/engine_adapter/test_risingwave.py @@ -3,7 +3,7 @@ from unittest.mock import call import pytest -from sqlglot import parse_one +from sqlglot import parse_one, exp from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter pytestmark = [pytest.mark.engine, pytest.mark.risingwave] @@ -15,6 +15,43 @@ def adapter(make_mocked_engine_adapter): return adapter +def test_columns(adapter: t.Callable): + adapter.cursor.fetchall.return_value = [ + ("smallint_col", "smallint"), + ("int_col", "integer"), + ("bigint_col", "bigint"), + ("ts_col", "timestamp without time zone"), + ("tstz_col", "timestamp with time zone"), + ("int_array_col", "integer[]"), + ("vchar_col", "character varying"), + ("struct_col", "struct"), + ] + resp = adapter.columns("db.table") + assert resp == { + "smallint_col": exp.DataType.build(exp.DataType.Type.SMALLINT, nested=False), + "int_col": exp.DataType.build(exp.DataType.Type.INT, nested=False), + "bigint_col": exp.DataType.build(exp.DataType.Type.BIGINT, nested=False), + "ts_col": exp.DataType.build(exp.DataType.Type.TIMESTAMP, nested=False), + "tstz_col": exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ, nested=False), + "int_array_col": exp.DataType.build( + exp.DataType.Type.ARRAY, + expressions=[exp.DataType.build(exp.DataType.Type.INT, nested=False)], + nested=True, + ), + "vchar_col": exp.DataType.build(exp.DataType.Type.VARCHAR), + "struct_col": exp.DataType.build( + exp.DataType.Type.STRUCT, + expressions=[ + exp.ColumnDef( + this=exp.Identifier(this="nested_col", quoted=False), + kind=exp.DataType.build(exp.DataType.Type.INT, nested=False), + ) + ], + nested=True, + ), + } + + def test_create_view(adapter: t.Callable): adapter.create_view("db.view", parse_one("SELECT 1"), replace=True) adapter.create_view("db.view", parse_one("SELECT 1"), replace=False)