Skip to content

Commit 6e06c73

Browse files
gab23rgabrielAndreasAlbertQC
authored
feat: Add dy.infer_schema (#294)
Co-authored-by: gabriel <gabriel.g.robin@airbus.com> Co-authored-by: Andreas Albert <andreas.albert@quantco.com>
1 parent 2534bf7 commit 6e06c73

File tree

9 files changed

+630
-3
lines changed

9 files changed

+630
-3
lines changed

dataframely/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
"deserialize_schema",
7979
"read_parquet_metadata_schema",
8080
"read_parquet_metadata_collection",
81+
"infer_schema",
8182
"Any",
8283
"Binary",
8384
"Bool",
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Copyright (c) QuantCo 2025-2026
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
from .infer_schema import infer_schema
5+
6+
__all__ = ["infer_schema"]
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
# Copyright (c) QuantCo 2025-2026
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
"""Infer schema from a Polars DataFrame."""
4+
5+
from __future__ import annotations
6+
7+
import keyword
8+
import re
9+
10+
import polars as pl
11+
12+
# Mapping from a concrete Polars dtype class to the name of the corresponding
# dataframely column class, used to render "dy.<name>(...)" expressions.
# Dtypes missing from this mapping fall back to "dy.Any".
_POLARS_DTYPE_MAP: dict[type[pl.DataType], str] = {
    # Boolean
    pl.Boolean: "Bool",
    # Signed integers
    pl.Int8: "Int8",
    pl.Int16: "Int16",
    pl.Int32: "Int32",
    pl.Int64: "Int64",
    # Unsigned integers
    pl.UInt8: "UInt8",
    pl.UInt16: "UInt16",
    pl.UInt32: "UInt32",
    pl.UInt64: "UInt64",
    # Floating point
    pl.Float32: "Float32",
    pl.Float64: "Float64",
    # Text / binary
    pl.String: "String",
    pl.Binary: "Binary",
    # Temporal
    pl.Date: "Date",
    pl.Time: "Time",
    pl.Duration: "Duration",
    pl.Datetime: "Datetime",
    # Other scalar types
    pl.Object: "Object",
    pl.Categorical: "Categorical",
    pl.Decimal: "Decimal",
    pl.Enum: "Enum",
    # Nested types
    pl.List: "List",
    pl.Array: "Array",
    pl.Struct: "Struct",
}
38+
39+
40+
def infer_schema(
    df: pl.DataFrame,
    schema_name: str = "Schema",
) -> str:
    """Infer a dataframely schema from a Polars DataFrame.

    Inspects the DataFrame's schema (column names, dtypes, and nullability)
    and renders the corresponding dataframely ``Schema`` class as source code.

    Args:
        df: The Polars DataFrame to infer the schema from.
        schema_name: The name for the generated schema class.

    Returns:
        The schema code as a string.

    Example:
        >>> import polars as pl
        >>> from dataframely.experimental import infer_schema
        >>> df = pl.DataFrame({
        ...     "name": ["Alice", "Bob"],
        ...     "age": [25, 30],
        ...     "score": [95.5, None],
        ... })
        >>> print(infer_schema(df, "PersonSchema"))
        class PersonSchema(dy.Schema):
            name = dy.String()
            age = dy.Int64()
            score = dy.Float64(nullable=True)

    Attention:
        This functionality is considered unstable. It may be changed at any time
        without it being considered a breaking change.

    Raises:
        ValueError: If ``schema_name`` is not a valid Python identifier.
    """
    # Reject names that would produce a syntactically invalid class statement.
    if not schema_name.isidentifier():
        raise ValueError(
            f"schema_name must be a valid Python identifier, got {schema_name!r}"
        )
    return _generate_schema_code(df, schema_name)
82+
83+
84+
def _generate_schema_code(df: pl.DataFrame, schema_name: str) -> str:
    """Render the dataframely schema class source for ``df``."""
    seen: set[str] = set()
    body: list[str] = []

    for position, (column, series) in enumerate(df.to_dict().items()):
        attr = _make_valid_identifier(column, position)
        # Make sure we have no duplicate attribute names.
        if attr in seen:
            # Drop a single trailing "_" (e.g. from keyword-escaping) since the
            # numeric suffix re-introduces an underscore anyway.
            base = attr.removesuffix("_")
            counter = 1
            while f"{base}_{counter}" in seen:
                counter += 1
            attr = f"{base}_{counter}"
        seen.add(attr)
        # When the attribute had to be renamed, keep the original column name
        # reachable through an explicit alias.
        alias = None if attr == column else column
        body.append(f"    {attr} = {_dtype_to_column_code(series, alias=alias)}")

    return "\n".join([f"class {schema_name}(dy.Schema):", *body])
106+
107+
108+
def _make_valid_identifier(name: str, col_index: int) -> str:
    """Convert an arbitrary column name to a valid Python identifier."""
    # Every character that is not legal in an identifier becomes "_".
    candidate = re.sub(r"[^a-zA-Z0-9_]", "_", name)

    # Empty names (or names consisting solely of replaced characters) get a
    # positional fallback name instead of a bare run of underscores.
    if not candidate.strip("_"):
        return f"column_{col_index}"
    # Identifiers must not start with a digit.
    if candidate[0].isdigit():
        return f"_{candidate}"
    # Avoid clashing with Python keywords (e.g. "class" -> "class_").
    if keyword.iskeyword(candidate):
        return f"{candidate}_"
    return candidate
122+
123+
124+
def _get_dtype_args(dtype: pl.DataType, series: pl.Series) -> list[str]:
    """Return extra constructor arguments for parameterized Polars dtypes."""
    if isinstance(dtype, pl.Datetime):
        extra: list[str] = []
        if dtype.time_zone is not None:
            extra.append(f'time_zone="{dtype.time_zone}"')
        if dtype.time_unit != "us":
            extra.append(f'time_unit="{dtype.time_unit}"')
        return extra

    # "us" is the Polars default time unit and needs no explicit argument.
    if isinstance(dtype, pl.Duration) and dtype.time_unit != "us":
        return [f'time_unit="{dtype.time_unit}"']

    if isinstance(dtype, pl.Decimal):
        extra = []
        if dtype.precision is not None:
            extra.append(f"precision={dtype.precision}")
        if dtype.scale != 0:
            extra.append(f"scale={dtype.scale}")
        return extra

    if isinstance(dtype, pl.Enum):
        return [repr(dtype.categories.to_list())]

    # Nested types recurse into the element type via the exploded series.
    if isinstance(dtype, pl.List):
        return [_dtype_to_column_code(series.explode())]

    if isinstance(dtype, pl.Array):
        return [_dtype_to_column_code(series.explode()), f"shape={dtype.size}"]

    if isinstance(dtype, pl.Struct):
        rendered_fields = ", ".join(
            f'"{field.name}": {_dtype_to_column_code(series.struct.field(field.name))}'
            for field in dtype.fields
        )
        return ["{" + rendered_fields + "}"]

    return []
163+
164+
165+
def _format_args(*args: str, nullable: bool = False, alias: str | None = None) -> str:
    """Join positional and keyword arguments into a constructor argument list."""
    parts = [*args]
    if nullable:
        parts.append("nullable=True")
    if alias is not None:
        # NOTE(review): an alias containing a double quote would render as
        # syntactically invalid code here — confirm whether such column names
        # need escaping upstream.
        parts.append(f'alias="{alias}"')
    return ", ".join(parts)
173+
174+
175+
def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str:
    """Build the dataframely column constructor expression for ``series``."""
    dtype = series.dtype
    column_cls = _POLARS_DTYPE_MAP.get(type(dtype))

    # Dtypes without a dedicated dataframely column fall back to dy.Any.
    if column_cls is None:
        return f"dy.Any({_format_args(alias=alias)})"

    has_nulls = series.null_count() > 0
    extra = _get_dtype_args(dtype, series)
    return f"dy.{column_cls}({_format_args(*extra, nullable=has_nulls, alias=alias)})"

docs/api/experimental/index.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
=============
2+
Experimental
3+
=============
4+
5+
.. currentmodule:: dataframely
6+
.. autosummary::
7+
:toctree: _gen/
8+
:nosignatures:
9+
10+
experimental.infer_schema

docs/api/index.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,12 @@ API Reference
4747
:maxdepth: 1
4848

4949
misc/index
50+
51+
.. grid::
52+
53+
.. grid-item-card::
54+
55+
.. toctree::
56+
:maxdepth: 1
57+
58+
experimental/index

docs/guides/faq.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ thinking, please add it here.
55

66
## How do I define additional unique keys in a {class}`~dataframely.Schema`?
77

8-
By default, `dataframely` only supports defining a single non-nullable (composite) primary key in :class:
8+
By default, `dataframely` only supports defining a single non-nullable (composite) primary key in {class}
99
`~dataframely.Schema`.
1010
However, in some scenarios it may be useful to define additional unique keys (which support nullable fields and/or which
1111
are additionally unique).
@@ -34,6 +34,13 @@ class UserSchema(dy.Schema):
3434

3535
See our documentation on [group rules](./quickstart.md#group-rules).
3636

37+
## How do I get a {class}`~dataframely.Schema` for my dataframe?
38+
39+
You can use {func}`dataframely.experimental.infer_schema` to get a basic {class}`~dataframely.Schema` definition for
40+
your dataframe. The function infers column names, types and nullability from the dataframe and returns a code
41+
representation of a {class}`~dataframely.Schema`.
42+
Starting from this definition, you can then refine the schema by adding additional rules, adjusting types, etc.
43+
3744
## What versions of `polars` does `dataframely` support?
3845

3946
Our CI automatically tests `dataframely` for a minimal supported version of `polars`, which is currently `1.35.*`,

docs/guides/migration/index.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@ Users can disable `FutureWarnings` either through
3737
builtins from tools
3838
like [pytest](https://docs.pytest.org/en/stable/how-to/capture-warnings.html#controlling-warnings),
3939
or by setting the `DATAFRAMELY_NO_FUTURE_WARNINGS` environment variable to `true` or `1`.
40+
41+
## Experimental features
42+
43+
Experimental features are published in a dedicated namespace `dataframely.experimental`.
44+
The versioning policy above does not apply to this namespace, and we may introduce breaking changes to experimental
45+
features in minor releases.

docs/guides/quickstart.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ print(failure.counts())
175175
```
176176

177177
In this case, `good` remains to be a `dy.DataFrame[HouseSchema]`, albeit with potentially fewer rows than `df`.
178-
The `failure` object is of type :class:`~dataframely.FailureInfo` and provides means to inspect
178+
The `failure` object is of type {class}`~dataframely.FailureInfo` and provides means to inspect
179179
the reasons for validation failures for invalid rows.
180180

181181
Given the example data above and the schema that we defined, we know that rows 2, 3, 4, and 5 are invalid (0-indexed):
@@ -185,7 +185,7 @@ Given the example data above and the schema that we defined, we know that rows 2
185185
- Row 4 violates both of the rules above
186186
- Row 5 violates the reasonable bathroom to bedroom ratio
187187

188-
Using the `counts` method on the :class:`~dataframely.FailureInfo` object will result in the following dictionary:
188+
Using the `counts` method on the {class}`~dataframely.FailureInfo` object will result in the following dictionary:
189189

190190
```python
191191
{

0 commit comments

Comments
 (0)