|
13 | 13 |
|
14 | 14 | pytestmark = [pytest.mark.risingwave, pytest.mark.engine, pytest.mark.slow] |
15 | 15 |
|
| 16 | + |
16 | 17 | @pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["risingwave"]))) |
17 | 18 | def ctx( |
18 | 19 | request: FixtureRequest, |
19 | | - create_test_context: t.Callable[[IntegrationTestEngine,str,str], t.Iterable], |
| 20 | + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable], |
20 | 21 | ) -> t.Iterable[TestContext]: |
21 | 22 | yield from create_test_context(*request.param) |
22 | 23 |
|
| 24 | + |
23 | 25 | @pytest.fixture |
24 | 26 | def engine_adapter(ctx: TestContext) -> RisingwaveEngineAdapter: |
25 | | - assert isinstance(ctx.engine_adapter,RisingwaveEngineAdapter) |
| 27 | + assert isinstance(ctx.engine_adapter, RisingwaveEngineAdapter) |
26 | 28 | return ctx.engine_adapter |
27 | 29 |
|
| 30 | + |
28 | 31 | @pytest.fixture |
29 | | -def risingwave_columns_with_datatypes(ctx: TestContext) -> t.Dict[str,exp.DataType]: |
| 32 | +def risingwave_columns_with_datatypes(ctx: TestContext) -> t.Dict[str, exp.DataType]: |
30 | 33 | base_types = { |
31 | | - "smallint_col" : exp.DataType.build(exp.DataType.Type.SMALLINT,nested=False), |
32 | | - "int_col" : exp.DataType.build(exp.DataType.Type.INT,nested = False), |
33 | | - "bigint_col" : exp.DataType.build(exp.DataType.Type.BIGINT, nested = False), |
34 | | - "ts_col" : exp.DataType.build(exp.DataType.Type.TIMESTAMP, nested = False), |
35 | | - "tstz_col" : exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ, nested = False), |
36 | | - "vchar_col" : exp.DataType.build(exp.DataType.Type.VARCHAR, nested = False), |
| 34 | + "smallint_col": exp.DataType.build(exp.DataType.Type.SMALLINT, nested=False), |
| 35 | + "int_col": exp.DataType.build(exp.DataType.Type.INT, nested=False), |
| 36 | + "bigint_col": exp.DataType.build(exp.DataType.Type.BIGINT, nested=False), |
| 37 | + "ts_col": exp.DataType.build(exp.DataType.Type.TIMESTAMP, nested=False), |
| 38 | + "tstz_col": exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ, nested=False), |
| 39 | + "vchar_col": exp.DataType.build(exp.DataType.Type.VARCHAR, nested=False), |
37 | 40 | } |
38 | 41 | # generate all arrays of base types |
39 | 42 | arr_types = { |
40 | | - f"{type_name}_arr_col" : exp.DataType.build( |
| 43 | + f"{type_name}_arr_col": exp.DataType.build( |
41 | 44 | exp.DataType.Type.ARRAY, |
42 | | - expressions = [base_type], |
43 | | - nested = True, |
44 | | - ) for type_name,base_type in base_types.items() |
| 45 | + expressions=[base_type], |
| 46 | + nested=True, |
| 47 | + ) |
| 48 | + for type_name, base_type in base_types.items() |
45 | 49 | } |
46 | 50 | # generate struct with all base types as nested columns |
47 | 51 | struct_types = { |
48 | | - "struct_col" : exp.DataType.build( |
| 52 | + "struct_col": exp.DataType.build( |
49 | 53 | exp.DataType.Type.STRUCT, |
50 | | - expressions = [ |
| 54 | + expressions=[ |
51 | 55 | exp.ColumnDef( |
52 | | - this = exp.Identifier(this=f"nested_{type_name}_col", quoted = False), |
53 | | - kind = base_type |
| 56 | + this=exp.Identifier(this=f"nested_{type_name}_col", quoted=False), |
| 57 | + kind=base_type, |
54 | 58 | ) |
55 | | - for type_name,base_type in base_types.items() |
| 59 | + for type_name, base_type in base_types.items() |
56 | 60 | ], |
57 | | - nested = True, |
| 61 | + nested=True, |
58 | 62 | ) |
59 | 63 | } |
60 | | - return { |
61 | | - **base_types, |
62 | | - **arr_types, |
63 | | - **struct_types |
64 | | - } |
| 64 | + return {**base_types, **arr_types, **struct_types} |
| 65 | + |
65 | 66 |
|
66 | 67 | def test_engine_adapter(ctx: TestContext): |
67 | | - assert isinstance(ctx.engine_adapter,RisingwaveEngineAdapter) |
68 | | - assert ctx.engine_adapter.fetchone("select 1")==(1,) |
| 68 | + assert isinstance(ctx.engine_adapter, RisingwaveEngineAdapter) |
| 69 | + assert ctx.engine_adapter.fetchone("select 1") == (1,) |
| 70 | + |
69 | 71 |
|
70 | | -def test_engine_adapter_columns(ctx: TestContext,risingwave_columns_with_datatypes: t.Dict[str,exp.DataType]): |
| 72 | +def test_engine_adapter_columns( |
| 73 | + ctx: TestContext, risingwave_columns_with_datatypes: t.Dict[str, exp.DataType] |
| 74 | +): |
71 | 75 | table = ctx.table("TEST_COLUMNS") |
72 | | - query_cols: t.List[str] = [f"NULL::{data_type.sql(dialect='risingwave')} AS {col_name}" for col_name,data_type in risingwave_columns_with_datatypes.items()] |
| 76 | + query_cols: t.List[str] = [ |
| 77 | + f"NULL::{data_type.sql(dialect='risingwave')} AS {col_name}" |
| 78 | + for col_name, data_type in risingwave_columns_with_datatypes.items() |
| 79 | + ] |
73 | 80 | query: exp.Query = exp.maybe_parse( |
74 | 81 | f""" |
75 | 82 | SELECT |
76 | | - {','.join(query_cols)} |
| 83 | + {",".join(query_cols)} |
77 | 84 | """, |
78 | 85 | dialect="risingwave", |
79 | 86 | ) |
80 | | - ctx.engine_adapter.ctas(table,query) |
| 87 | + ctx.engine_adapter.ctas(table, query) |
81 | 88 |
|
82 | 89 | column_result = ctx.engine_adapter.columns(table) |
83 | 90 | assert column_result == risingwave_columns_with_datatypes |
0 commit comments