Skip to content

Commit c19a86f

Browse files
test: Added integration test for risingwave columns function
1 parent 9c4a9ca commit c19a86f

1 file changed

Lines changed: 83 additions & 0 deletions

File tree

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import typing as t
2+
import pytest
3+
from sqlglot import exp
4+
from pytest import FixtureRequest
5+
from sqlmesh.core.engine_adapter import RisingwaveEngineAdapter
6+
from tests.core.engine_adapter.integration import (
7+
TestContext,
8+
generate_pytest_params,
9+
ENGINES_BY_NAME,
10+
IntegrationTestEngine,
11+
)
12+
13+
14+
pytestmark = [pytest.mark.risingwave, pytest.mark.engine, pytest.mark.slow]
15+
16+
@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["risingwave"])))
17+
def ctx(
18+
request: FixtureRequest,
19+
create_test_context: t.Callable[[IntegrationTestEngine,str,str], t.Iterable],
20+
) -> t.Iterable[TestContext]:
21+
yield from create_test_context(*request.param)
22+
23+
@pytest.fixture
24+
def engine_adapter(ctx: TestContext) -> RisingwaveEngineAdapter:
25+
assert isinstance(ctx.engine_adapter,RisingwaveEngineAdapter)
26+
return ctx.engine_adapter
27+
28+
@pytest.fixture
29+
def risingwave_columns_with_datatypes(ctx: TestContext) -> t.Dict[str,exp.DataType]:
30+
base_types = {
31+
"smallint_col" : exp.DataType.build(exp.DataType.Type.SMALLINT,nested=False),
32+
"int_col" : exp.DataType.build(exp.DataType.Type.INT,nested = False),
33+
"bigint_col" : exp.DataType.build(exp.DataType.Type.BIGINT, nested = False),
34+
"ts_col" : exp.DataType.build(exp.DataType.Type.TIMESTAMP, nested = False),
35+
"tstz_col" : exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ, nested = False),
36+
"vchar_col" : exp.DataType.build(exp.DataType.Type.VARCHAR, nested = False),
37+
}
38+
# generate all arrays of base types
39+
arr_types = {
40+
f"{type_name}_arr_col" : exp.DataType.build(
41+
exp.DataType.Type.ARRAY,
42+
expressions = [base_type],
43+
nested = True,
44+
) for type_name,base_type in base_types.items()
45+
}
46+
# generate struct with all base types as nested columns
47+
struct_types = {
48+
"struct_col" : exp.DataType.build(
49+
exp.DataType.Type.STRUCT,
50+
expressions = [
51+
exp.ColumnDef(
52+
this = exp.Identifier(this=f"nested_{type_name}_col", quoted = False),
53+
kind = base_type
54+
)
55+
for type_name,base_type in base_types.items()
56+
],
57+
nested = True,
58+
)
59+
}
60+
return {
61+
**base_types,
62+
**arr_types,
63+
**struct_types
64+
}
65+
66+
def test_engine_adapter(ctx: TestContext):
67+
assert isinstance(ctx.engine_adapter,RisingwaveEngineAdapter)
68+
assert ctx.engine_adapter.fetchone("select 1")==(1,)
69+
70+
def test_engine_adapter_columns(ctx: TestContext,risingwave_columns_with_datatypes: t.Dict[str,exp.DataType]):
71+
table = ctx.table("TEST_COLUMNS")
72+
query_cols: t.List[str] = [f"NULL::{data_type.sql(dialect='risingwave')} AS {col_name}" for col_name,data_type in risingwave_columns_with_datatypes.items()]
73+
query: exp.Query = exp.maybe_parse(
74+
f"""
75+
SELECT
76+
{','.join(query_cols)}
77+
""",
78+
dialect="risingwave",
79+
)
80+
ctx.engine_adapter.ctas(table,query)
81+
82+
column_result = ctx.engine_adapter.columns(table)
83+
assert column_result == risingwave_columns_with_datatypes

0 commit comments

Comments
 (0)