Skip to content

Commit 4e92fea

Browse files
Claude and borchero authored
feat: Introduce pydantic conversion for schemas (#324)
Co-authored-by: anthropic-code-agent[bot] <242468646+Claude@users.noreply.github.com> Co-authored-by: borchero <22455425+borchero@users.noreply.github.com> Co-authored-by: Oliver Borchert <oliver.borchert@quantco.com>
1 parent 49121b8 commit 4e92fea

File tree

19 files changed

+702
-8
lines changed

19 files changed

+702
-8
lines changed

dataframely/columns/_base.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
import inspect
77
import sys
8+
import warnings
89
from abc import ABC, abstractmethod
910
from collections import Counter
1011
from collections.abc import Callable, Mapping, Sequence
11-
from typing import Any, TypeAlias, cast
12+
from typing import Annotated, Any, TypeAlias, cast
1213

1314
import polars as pl
1415

15-
from dataframely._compat import pa, sa, sa_TypeEngine
16+
from dataframely._compat import pa, pydantic, sa, sa_TypeEngine
1617
from dataframely._polars import PolarsDataType
1718
from dataframely.random import Generator
1819

@@ -222,6 +223,50 @@ def pyarrow_field(self, name: str) -> pa.Field:
222223
def pyarrow_dtype(self) -> pa.DataType:
223224
"""The :mod:`pyarrow` dtype equivalent of this column data type."""
224225

226+
# ----------------------------------- PYDANTIC ----------------------------------- #
227+
228+
def pydantic_field(self) -> Any:
    """Obtain a pydantic field type for this column definition.

    Returns:
        A pydantic-compatible type annotation that includes structured constraints
        (such as `min`, `max`, ...).

    Warning:
        Custom checks are not translated to pydantic validators.
    """
    if self.check is not None:
        # ``stacklevel=2`` attributes the warning to the caller of
        # ``pydantic_field`` rather than to this library-internal line.
        warnings.warn(
            f"Custom checks for column '{self.name or self.__class__.__name__}' "
            "are not translated to pydantic constraints.",
            stacklevel=2,
        )

    python_type = self._python_type
    if self.nullable:
        # Nullable columns admit ``None`` in the pydantic model as well.
        python_type = python_type | None

    field_kwargs = self._pydantic_field_kwargs()
    if field_kwargs:
        # Attach structured constraints via ``Annotated`` so the plain type
        # annotation is preserved alongside the pydantic ``Field`` metadata.
        return Annotated[python_type, pydantic.Field(**field_kwargs)]
    return python_type
252+
253+
@property
@abstractmethod
def _python_type(self) -> Any:
    """The native Python type corresponding to this column definition.

    Subclasses return the plain Python type (e.g. ``int``, ``str``,
    ``datetime.date``) used as the base annotation by :meth:`pydantic_field`.
    """
257+
258+
def _pydantic_field_kwargs(self) -> dict[str, Any]:
    """Return kwargs for pydantic.Field initialization.

    This hook is designed for cooperative extension: subclasses and mixins
    override it, call ``super()._pydantic_field_kwargs()``, and add their
    specific constraints to the returned dictionary.

    Returns:
        A dictionary of kwargs to pass to pydantic.Field.
    """
    # The base column imposes no structured constraints.
    return dict()
269+
225270
# ------------------------------------ HELPER ------------------------------------ #
226271

227272
@property

dataframely/columns/_mixins.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,18 @@ def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]:
8080
result["max_exclusive"] = expr < self.max_exclusive # type: ignore
8181
return result
8282

83+
def _pydantic_field_kwargs(self) -> dict[str, Any]:
    """Translate min/max bounds into pydantic comparison constraints."""
    kwargs = super()._pydantic_field_kwargs()
    # Map each configured bound onto the corresponding pydantic keyword.
    bound_map = (
        ("ge", self.min),
        ("gt", self.min_exclusive),
        ("le", self.max),
        ("lt", self.max_exclusive),
    )
    for keyword, bound in bound_map:
        if bound is not None:
            kwargs[keyword] = bound
    return kwargs
8395

8496
# ------------------------------------ IS IN MIXIN ----------------------------------- #
8597

dataframely/columns/any.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from __future__ import annotations
55

6+
from typing import Any as AnyType
7+
68
import polars as pl
79

810
from dataframely._compat import pa, sa, sa_mssql, sa_TypeEngine
@@ -77,5 +79,9 @@ def pyarrow_field(self, name: str) -> pa.Field:
7779
def pyarrow_dtype(self) -> pa.DataType:
7880
return pa.null()
7981

82+
@property
def _python_type(self) -> AnyType:
    """Any value is admissible for this column, hence ``typing.Any``."""
    return AnyType
8086
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
8187
return pl.repeat(None, n, dtype=pl.Null, eager=True)

dataframely/columns/array.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import math
77
import sys
8+
import warnings
89
from collections.abc import Sequence
910
from typing import Any, Literal, cast
1011

@@ -121,6 +122,23 @@ def _pyarrow_field_of_shape(self, shape: Sequence[int]) -> pa.Field:
121122
def pyarrow_dtype(self) -> pa.DataType:
122123
return self._pyarrow_field_of_shape(self.shape).type
123124

125+
@property
def _python_type(self) -> Any:
    """A list whose element annotation is the inner column's pydantic field."""
    element_type = self.inner.pydantic_field()
    return list[element_type]  # type: ignore
129+
130+
def _pydantic_field_kwargs(self) -> dict[str, Any]:
    """Constrain the validated list length to the total number of array elements.

    Multi-dimensional shapes cannot be expressed as nested pydantic
    constraints, so arrays are validated as a flat list of exactly
    ``math.prod(self.shape)`` elements.
    """
    if len(self.shape) != 1:
        # ``stacklevel=2`` attributes the warning to the caller rather than
        # to this library-internal line.
        warnings.warn(
            "Multi-dimensional arrays are flattened for pydantic validation.",
            stacklevel=2,
        )

    # The flat element count is both the lower and upper length bound.
    n_elements = math.prod(self.shape)
    return {
        **super()._pydantic_field_kwargs(),
        "min_length": n_elements,
        "max_length": n_elements,
    }
141+
124142
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
125143
# Sample the inner elements in a flat series
126144
n_elements = n * math.prod(self.shape)

dataframely/columns/binary.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from __future__ import annotations
55

6+
from typing import Any
7+
68
import polars as pl
79

810
from dataframely._compat import pa, sa, sa_TypeEngine
@@ -31,6 +33,10 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
3133
def pyarrow_dtype(self) -> pa.DataType:
3234
return pa.large_binary()
3335

36+
@property
37+
def _python_type(self) -> Any:
38+
return bytes
39+
3440
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
3541
return generator.sample_binary(
3642
n,

dataframely/columns/bool.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from __future__ import annotations
55

6+
from typing import Any
7+
68
import polars as pl
79

810
from dataframely._compat import pa, sa, sa_TypeEngine
@@ -27,5 +29,9 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
2729
def pyarrow_dtype(self) -> pa.DataType:
2830
return pa.bool_()
2931

32+
@property
33+
def _python_type(self) -> Any:
34+
return bool
35+
3036
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
3137
return generator.sample_bool(n, null_probability=self._null_probability)

dataframely/columns/categorical.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
7171
def pyarrow_dtype(self) -> pa.DataType:
7272
return pa.dictionary(pa.uint32(), pa.large_string())
7373

74+
@property
75+
def _python_type(self) -> Any:
76+
return str
77+
7478
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
7579
# We simply sample low-cardinality strings here
7680
return generator.sample_string(

dataframely/columns/datetime.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from __future__ import annotations
55

66
import datetime as dt
7+
import warnings
78
from typing import Any, cast
89

910
import polars as pl
@@ -132,6 +133,16 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
132133
def pyarrow_dtype(self) -> pa.DataType:
133134
return pa.date32()
134135

136+
@property
def _python_type(self) -> Any:
    """Date columns map to Python ``datetime.date``."""
    return dt.date

def _pydantic_field_kwargs(self) -> dict[str, Any]:
    """Return pydantic.Field kwargs; a configured resolution is only warned about."""
    if self.resolution is not None:
        # ``stacklevel=2`` attributes the warning to the caller.
        warnings.warn(
            "Date resolution is not translated to a pydantic constraint.",
            stacklevel=2,
        )

    return super()._pydantic_field_kwargs()
145+
135146
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
136147
return generator.sample_date(
137148
n,
@@ -261,6 +272,16 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
261272
def pyarrow_dtype(self) -> pa.DataType:
262273
return pa.time64("ns")
263274

275+
@property
def _python_type(self) -> Any:
    """Time columns map to Python ``datetime.time``."""
    return dt.time

def _pydantic_field_kwargs(self) -> dict[str, Any]:
    """Return pydantic.Field kwargs; a configured resolution is only warned about."""
    if self.resolution is not None:
        # ``stacklevel=2`` attributes the warning to the caller.
        warnings.warn(
            "Time resolution is not translated to a pydantic constraint.",
            stacklevel=2,
        )

    return super()._pydantic_field_kwargs()
284+
264285
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
265286
return generator.sample_time(
266287
n,
@@ -394,6 +415,22 @@ def pyarrow_dtype(self) -> pa.DataType:
394415
)
395416
return pa.timestamp(self.time_unit, time_zone)
396417

418+
@property
def _python_type(self) -> Any:
    """Datetime columns map to Python ``datetime.datetime``."""
    return dt.datetime

def _pydantic_field_kwargs(self) -> dict[str, Any]:
    """Return pydantic.Field kwargs; resolution and time zone are only warned about."""
    if self.resolution is not None:
        # ``stacklevel=2`` attributes the warning to the caller.
        warnings.warn(
            "Datetime resolution is not translated to a pydantic constraint.",
            stacklevel=2,
        )
    if self.time_zone is not None:
        warnings.warn(
            "Datetime time zone is not translated to a pydantic constraint.",
            stacklevel=2,
        )

    return super()._pydantic_field_kwargs()
433+
397434
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
398435
return generator.sample_datetime(
399436
n,
@@ -531,6 +568,18 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine:
531568
def pyarrow_dtype(self) -> pa.DataType:
532569
return pa.duration(self.time_unit)
533570

571+
@property
def _python_type(self) -> Any:
    """Duration columns map to Python ``datetime.timedelta``."""
    return dt.timedelta

def _pydantic_field_kwargs(self) -> dict[str, Any]:
    """Return pydantic.Field kwargs; a configured resolution is only warned about."""
    if self.resolution is not None:
        # ``stacklevel=2`` attributes the warning to the caller.
        warnings.warn(
            "Duration resolution is not translated to a pydantic constraint.",
            stacklevel=2,
        )

    return super()._pydantic_field_kwargs()
582+
534583
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
535584
# NOTE: If no duration is specified, we default to 100 years
536585
return generator.sample_duration(

dataframely/columns/decimal.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,16 @@ def pyarrow_dtype(self) -> pa.DataType:
128128
# We do not use decimal256 since its values cannot be represented in SQL Server.
129129
return pa.decimal128(self.precision or 38, self.scale)
130130

131+
@property
def _python_type(self) -> Any:
    """Decimal columns map to Python ``decimal.Decimal``."""
    return decimal.Decimal

def _pydantic_field_kwargs(self) -> dict[str, Any]:
    """Translate scale and precision into pydantic decimal constraints."""
    kwargs = {
        **super()._pydantic_field_kwargs(),
        "decimal_places": self.scale,
    }
    # ``max_digits`` mirrors the column's precision. Only add it when the
    # user set a precision explicitly — the fallback of 38 used for SQL and
    # pyarrow is a storage default, not a user-declared constraint.
    if self.precision is not None:
        kwargs["max_digits"] = self.precision
    return kwargs
140+
131141
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
132142
# NOTE: Default precision to 38 for sampling, just like for SQL and Pyarrow
133143
precision = self.precision or 38

dataframely/columns/enum.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import enum
77
from collections.abc import Iterable
88
from inspect import isclass
9-
from typing import Any
9+
from typing import Any, Literal
1010

1111
import polars as pl
1212

@@ -95,6 +95,10 @@ def pyarrow_dtype(self) -> pa.DataType:
9595
dtype = pa.uint32()
9696
return pa.dictionary(dtype, pa.large_string())
9797

98+
@property
def _python_type(self) -> Any:
    """A ``Literal`` type enumerating exactly the enum's categories."""
    allowed = tuple(self.categories)
    return Literal[allowed]
101+
98102
def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series:
99103
return generator.sample_choice(
100104
n,

0 commit comments

Comments
 (0)