Skip to content

Commit c68133a

Browse files
feat: Add support for converting dy.List, dy.Array to postgres array (#256)
1 parent 25cdb9b commit c68133a

File tree

4 files changed

+32
-12
lines changed

4 files changed

+32
-12
lines changed

dataframely/columns/array.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) QuantCo 2025-2025
1+
# Copyright (c) QuantCo 2025-2026
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from __future__ import annotations
@@ -97,8 +97,17 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
9797
}
9898

9999
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
100-
# NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future.
101-
raise NotImplementedError("SQL column cannot have 'Array' type.")
100+
match dialect.name:
101+
case "postgresql":
102+
# NOTE: SQLAlchemy's ARRAY type cannot declare a length for each dimension,
103+
# which loses nothing because PostgreSQL does not enforce declared array lengths anyway.
104+
return sa.ARRAY(
105+
self.inner.sqlalchemy_dtype(dialect), dimensions=len(self.shape)
106+
)
107+
case _:
108+
raise NotImplementedError(
109+
f"SQL column cannot have 'Array' type for dialect '{dialect}'."
110+
)
102111

103112
def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field:
104113
if shape:

dataframely/columns/binary.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) QuantCo 2025-2025
1+
# Copyright (c) QuantCo 2025-2026
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from __future__ import annotations
@@ -21,9 +21,11 @@ def dtype(self) -> pl.DataType:
2121
return pl.Binary()
2222

2323
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
24-
if dialect.name == "mssql":
25-
return sa.VARBINARY()
26-
return sa.LargeBinary()
24+
match dialect.name:
25+
case "mssql":
26+
return sa.VARBINARY()
27+
case _:
28+
return sa.LargeBinary()
2729

2830
@property
2931
def pyarrow_dtype(self) -> pa.DataType:

dataframely/columns/list.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) QuantCo 2025-2025
1+
# Copyright (c) QuantCo 2025-2026
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from __future__ import annotations
@@ -120,8 +120,13 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
120120
}
121121

122122
def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
123-
# NOTE: We might want to add support for PostgreSQL's ARRAY type or use JSON in the future.
124-
raise NotImplementedError("SQL column cannot have 'List' type.")
123+
match dialect.name:
124+
case "postgresql":
125+
return sa.ARRAY(self.inner.sqlalchemy_dtype(dialect))
126+
case _:
127+
raise NotImplementedError(
128+
f"SQL column cannot have 'List' type for dialect '{dialect}'."
129+
)
125130

126131
@property
127132
def pyarrow_dtype(self) -> pa.DataType:

tests/columns/test_sqlalchemy_columns.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ def test_mssql_datatype(column: Column, datatype: str) -> None:
9595
(dy.String(regex="^[abc]{1,3}d$"), "VARCHAR(4)"),
9696
(dy.Enum(["foo", "bar"]), "CHAR(3)"),
9797
(dy.Enum(["a", "abc"]), "VARCHAR(3)"),
98+
(dy.List(dy.Integer()), "INTEGER[]"),
99+
(dy.List(dy.String(max_length=5)), "VARCHAR(5)[]"),
100+
(dy.Array(dy.Integer(), shape=5), "INTEGER[]"),
101+
(dy.Array(dy.String(max_length=5), shape=(2, 1)), "VARCHAR(5)[][]"),
98102
(dy.Struct({"a": dy.String(nullable=True)}), "JSONB"),
99103
],
100104
)
@@ -137,15 +141,15 @@ def test_sql_multiple_columns(dialect: Dialect) -> None:
137141
assert len(schema.to_sqlalchemy_columns(dialect)) == 2
138142

139143

140-
@pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()])
144+
@pytest.mark.parametrize("dialect", [MSDialect_pyodbc()])
141145
def test_raise_for_list_column(dialect: Dialect) -> None:
142146
with pytest.raises(
143147
NotImplementedError, match="SQL column cannot have 'List' type."
144148
):
145149
dy.List(dy.String()).sqlalchemy_dtype(dialect)
146150

147151

148-
@pytest.mark.parametrize("dialect", [MSDialect_pyodbc(), PGDialect_psycopg2()])
152+
@pytest.mark.parametrize("dialect", [MSDialect_pyodbc()])
149153
def test_raise_for_array_column(dialect: Dialect) -> None:
150154
with pytest.raises(
151155
NotImplementedError, match="SQL column cannot have 'Array' type."

0 commit comments

Comments
 (0)