Skip to content

Commit d4e15b3

Browse files
fix: ensure correct datatypes are fethed for RisingWave dialect
1 parent b81b109 commit d4e15b3

2 files changed

Lines changed: 69 additions & 1 deletion

File tree

sqlmesh/core/engine_adapter/risingwave.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
CommentCreationTable,
1515
)
1616

17+
from sqlmesh.utils.errors import SQLMeshError
1718

1819
if t.TYPE_CHECKING:
1920
from sqlmesh.core._typing import TableName
@@ -32,5 +33,31 @@ class RisingwaveEngineAdapter(PostgresEngineAdapter):
3233
SUPPORTS_TRANSACTIONS = False
3334
MAX_IDENTIFIER_LENGTH = None
3435

36+
def columns(
37+
self, table_name: TableName, include_pseudo_columns: bool = False
38+
) -> t.Dict[str, exp.DataType]:
39+
"""Fetches column names and types for the target_table"""
40+
table = exp.to_table(table_name)
41+
42+
sql = (
43+
exp.select("rw_columns.name AS column_name", "rw_columns.data_type AS data_type")
44+
.from_("rw_catalog.rw_columns")
45+
.join("rw_catalog.rw_relations", on="rw_relations.id=rw_columns.relation_id")
46+
.join("rw_catalog.rw_schemas", on="rw_schemas.id=rw_relations.schema_id")
47+
.where(exp.column("rw_relations.name", quoted=False).eq(table.alias_or_name))
48+
)
49+
50+
if table.args.get("db"):
51+
sql = sql.where(exp.column("rw_schemas.name").eq(table.args["db"].name))
52+
53+
self.execute(sql)
54+
resp = self.cursor.fetchall()
55+
if not resp:
56+
raise SQLMeshError(f"Could not get columns for table {table_name}. Table not found.")
57+
return {
58+
column_name: exp.DataType.build(data_type, dialect=self.dialect, udt=True)
59+
for column_name, data_type in resp
60+
}
61+
3562
def _truncate_table(self, table_name: TableName) -> None:
3663
return self.execute(exp.Delete(this=exp.to_table(table_name)))

tests/core/engine_adapter/test_risingwave.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from unittest.mock import call
44

55
import pytest
6-
from sqlglot import parse_one
6+
from sqlglot import parse_one, exp
77
from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter
88

99
pytestmark = [pytest.mark.engine, pytest.mark.risingwave]
@@ -14,6 +14,47 @@ def adapter(make_mocked_engine_adapter):
1414
adapter = make_mocked_engine_adapter(RisingwaveEngineAdapter)
1515
return adapter
1616

17+
def test_columns(adapter: t.Callable):
18+
adapter.cursor.fetchall.return_value = [
19+
("smallint_col","smallint"),
20+
("int_col","integer"),
21+
("bigint_col","bigint"),
22+
("ts_col","timestamp without time zone"),
23+
("tstz_col","timestamp with time zone"),
24+
("int_array_col","integer[]"),
25+
("vchar_col","character varying"),
26+
("struct_col","struct<nested_col integer>")
27+
]
28+
resp = adapter.columns("db.table")
29+
assert resp == {
30+
"smallint_col" : exp.DataType.build(exp.DataType.Type.SMALLINT,nested=False),
31+
"int_col" : exp.DataType.build(exp.DataType.Type.INT,nested=False),
32+
"bigint_col" : exp.DataType.build(exp.DataType.Type.BIGINT,nested=False),
33+
"ts_col" : exp.DataType.build(exp.DataType.Type.TIMESTAMP,nested=False),
34+
"tstz_col" : exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ,nested=False),
35+
"int_array_col" : exp.DataType.build(
36+
exp.DataType.Type.ARRAY,
37+
expressions = [exp.DataType.build(exp.DataType.Type.INT,nested=False)],
38+
nested = True
39+
),
40+
"vchar_col": exp.DataType.build(exp.DataType.Type.VARCHAR),
41+
"struct_col": exp.DataType.build(
42+
exp.DataType.Type.STRUCT,
43+
expressions = [
44+
exp.ColumnDef(
45+
this = exp.Identifier(
46+
this = "nested_col",
47+
quoted = False
48+
),
49+
kind = exp.DataType.build(
50+
exp.DataType.Type.INT,
51+
nested = False
52+
)
53+
)
54+
],
55+
nested = True
56+
),
57+
}
1758

1859
def test_create_view(adapter: t.Callable):
1960
adapter.create_view("db.view", parse_one("SELECT 1"), replace=True)

0 commit comments

Comments
 (0)