Skip to content

Commit f71d66b

Browse files
gab23rgabrielAndreasAlbertQC
authored
fix: Ignore dy.Any columns in Schema.cast (#315)
Co-authored-by: gabriel <gabriel.g.robin@airbus.com> Co-authored-by: Andreas Albert <103571926+AndreasAlbertQC@users.noreply.github.com> Co-authored-by: Andreas Albert <andreas.albert@quantco.com>
1 parent 7b47919 commit f71d66b

4 files changed

Lines changed: 28 additions & 10 deletions

File tree

dataframely/schema.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -813,9 +813,7 @@ def cast(
813813
the lazy frame's schema but also means that a call to :meth:`polars.LazyFrame.collect`
814814
further down the line might fail because of the cast and/or missing columns.
815815
"""
816-
lf = df.lazy().select(
817-
pl.col(name).cast(col.dtype) for name, col in cls.columns().items()
818-
)
816+
lf = match_to_schema(df.lazy(), cls, casting="strict")
819817
if isinstance(df, pl.DataFrame):
820818
return lf.collect() # type: ignore
821819
return lf # type: ignore

tests/collection/test_cast.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
import polars as pl
5-
import polars.exceptions as plexc
65
import pytest
76

87
import dataframely as dy
8+
from dataframely.exc import SchemaError
99

1010

1111
class FirstSchema(dy.Schema):
@@ -48,12 +48,12 @@ def test_cast_invalid_members(df_type: type[pl.DataFrame] | type[pl.LazyFrame])
4848

4949
def test_cast_invalid_member_schema_eager() -> None:
5050
first = pl.DataFrame({"b": [3]})
51-
with pytest.raises(plexc.ColumnNotFoundError):
51+
with pytest.raises(SchemaError):
5252
Collection.cast({"first": first})
5353

5454

5555
def test_cast_invalid_member_schema_lazy() -> None:
5656
first = pl.LazyFrame({"b": [3]})
5757
collection = Collection.cast({"first": first})
58-
with pytest.raises(plexc.ColumnNotFoundError):
58+
with pytest.raises(SchemaError):
5959
collection.collect_all()

tests/column_types/test_any.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,9 @@ class AnySchema(dy.Schema):
2020
def test_any_dtype_passes(data: dict[str, Any]) -> None:
2121
df = pl.DataFrame(data)
2222
assert AnySchema.is_valid(df)
23+
24+
25+
def test_any_cast() -> None:
26+
df = pl.DataFrame({"a": 0})
27+
result = AnySchema.cast(df)
28+
assert result["a"].dtype == pl.Int64

tests/schema/test_cast.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# Copyright (c) QuantCo 2025-2026
22
# SPDX-License-Identifier: BSD-3-Clause
3-
43
from typing import Any
54

65
import polars as pl
7-
import polars.exceptions as plexc
86
import pytest
97

108
import dataframely as dy
9+
from dataframely.exc import SchemaError
1110

1211

1312
class MySchema(dy.Schema):
@@ -34,12 +33,27 @@ def test_cast_valid(
3433

3534
def test_cast_invalid_schema_eager() -> None:
3635
df = pl.DataFrame({"a": [1]})
37-
with pytest.raises(plexc.ColumnNotFoundError):
36+
with pytest.raises(SchemaError):
3837
MySchema.cast(df)
3938

4039

4140
def test_cast_invalid_schema_lazy() -> None:
4241
lf = pl.LazyFrame({"a": [1]})
4342
lf = MySchema.cast(lf)
44-
with pytest.raises(plexc.ColumnNotFoundError):
43+
with pytest.raises(SchemaError):
4544
lf.collect()
45+
46+
47+
class IntegerSchema(dy.Schema):
48+
a = dy.Integer()
49+
50+
51+
@pytest.mark.parametrize("df_type", [pl.DataFrame, pl.LazyFrame])
52+
def test_cast_preserves_valid_dtype(
53+
df_type: type[pl.DataFrame] | type[pl.LazyFrame],
54+
) -> None:
55+
"""Test that cast doesn't change already valid dtypes (issue #318)."""
56+
df = df_type({"a": [1, 2, 3]}, schema={"a": pl.Int32})
57+
result = IntegerSchema.cast(df)
58+
# Int32 is valid for dy.Integer, so it should NOT be cast to Int64
59+
assert result.lazy().collect_schema()["a"] == pl.Int32

0 commit comments

Comments
 (0)