Skip to content

Commit a1d4073

Browse files
fix: formatted new integration test
1 parent c19a86f commit a1d4073

1 file changed

Lines changed: 37 additions & 30 deletions

File tree

tests/core/engine_adapter/integration/test_integration_risingwave.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,71 +13,78 @@
1313

1414
pytestmark = [pytest.mark.risingwave, pytest.mark.engine, pytest.mark.slow]
1515

16+
1617
@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["risingwave"])))
1718
def ctx(
1819
request: FixtureRequest,
19-
create_test_context: t.Callable[[IntegrationTestEngine,str,str], t.Iterable],
20+
create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable],
2021
) -> t.Iterable[TestContext]:
2122
yield from create_test_context(*request.param)
2223

24+
2325
@pytest.fixture
2426
def engine_adapter(ctx: TestContext) -> RisingwaveEngineAdapter:
25-
assert isinstance(ctx.engine_adapter,RisingwaveEngineAdapter)
27+
assert isinstance(ctx.engine_adapter, RisingwaveEngineAdapter)
2628
return ctx.engine_adapter
2729

30+
2831
@pytest.fixture
29-
def risingwave_columns_with_datatypes(ctx: TestContext) -> t.Dict[str,exp.DataType]:
32+
def risingwave_columns_with_datatypes(ctx: TestContext) -> t.Dict[str, exp.DataType]:
3033
base_types = {
31-
"smallint_col" : exp.DataType.build(exp.DataType.Type.SMALLINT,nested=False),
32-
"int_col" : exp.DataType.build(exp.DataType.Type.INT,nested = False),
33-
"bigint_col" : exp.DataType.build(exp.DataType.Type.BIGINT, nested = False),
34-
"ts_col" : exp.DataType.build(exp.DataType.Type.TIMESTAMP, nested = False),
35-
"tstz_col" : exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ, nested = False),
36-
"vchar_col" : exp.DataType.build(exp.DataType.Type.VARCHAR, nested = False),
34+
"smallint_col": exp.DataType.build(exp.DataType.Type.SMALLINT, nested=False),
35+
"int_col": exp.DataType.build(exp.DataType.Type.INT, nested=False),
36+
"bigint_col": exp.DataType.build(exp.DataType.Type.BIGINT, nested=False),
37+
"ts_col": exp.DataType.build(exp.DataType.Type.TIMESTAMP, nested=False),
38+
"tstz_col": exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ, nested=False),
39+
"vchar_col": exp.DataType.build(exp.DataType.Type.VARCHAR, nested=False),
3740
}
3841
# generate all arrays of base types
3942
arr_types = {
40-
f"{type_name}_arr_col" : exp.DataType.build(
43+
f"{type_name}_arr_col": exp.DataType.build(
4144
exp.DataType.Type.ARRAY,
42-
expressions = [base_type],
43-
nested = True,
44-
) for type_name,base_type in base_types.items()
45+
expressions=[base_type],
46+
nested=True,
47+
)
48+
for type_name, base_type in base_types.items()
4549
}
4650
# generate struct with all base types as nested columns
4751
struct_types = {
48-
"struct_col" : exp.DataType.build(
52+
"struct_col": exp.DataType.build(
4953
exp.DataType.Type.STRUCT,
50-
expressions = [
54+
expressions=[
5155
exp.ColumnDef(
52-
this = exp.Identifier(this=f"nested_{type_name}_col", quoted = False),
53-
kind = base_type
56+
this=exp.Identifier(this=f"nested_{type_name}_col", quoted=False),
57+
kind=base_type,
5458
)
55-
for type_name,base_type in base_types.items()
59+
for type_name, base_type in base_types.items()
5660
],
57-
nested = True,
61+
nested=True,
5862
)
5963
}
60-
return {
61-
**base_types,
62-
**arr_types,
63-
**struct_types
64-
}
64+
return {**base_types, **arr_types, **struct_types}
65+
6566

6667
def test_engine_adapter(ctx: TestContext):
67-
assert isinstance(ctx.engine_adapter,RisingwaveEngineAdapter)
68-
assert ctx.engine_adapter.fetchone("select 1")==(1,)
68+
assert isinstance(ctx.engine_adapter, RisingwaveEngineAdapter)
69+
assert ctx.engine_adapter.fetchone("select 1") == (1,)
70+
6971

70-
def test_engine_adapter_columns(ctx: TestContext,risingwave_columns_with_datatypes: t.Dict[str,exp.DataType]):
72+
def test_engine_adapter_columns(
73+
ctx: TestContext, risingwave_columns_with_datatypes: t.Dict[str, exp.DataType]
74+
):
7175
table = ctx.table("TEST_COLUMNS")
72-
query_cols: t.List[str] = [f"NULL::{data_type.sql(dialect='risingwave')} AS {col_name}" for col_name,data_type in risingwave_columns_with_datatypes.items()]
76+
query_cols: t.List[str] = [
77+
f"NULL::{data_type.sql(dialect='risingwave')} AS {col_name}"
78+
for col_name, data_type in risingwave_columns_with_datatypes.items()
79+
]
7380
query: exp.Query = exp.maybe_parse(
7481
f"""
7582
SELECT
76-
{','.join(query_cols)}
83+
{",".join(query_cols)}
7784
""",
7885
dialect="risingwave",
7986
)
80-
ctx.engine_adapter.ctas(table,query)
87+
ctx.engine_adapter.ctas(table, query)
8188

8289
column_result = ctx.engine_adapter.columns(table)
8390
assert column_result == risingwave_columns_with_datatypes

0 commit comments

Comments
 (0)