|
| 1 | +# Copyright (c) QuantCo 2025-2026 |
| 2 | +# SPDX-License-Identifier: BSD-3-Clause |
| 3 | +"""Infer schema from a Polars DataFrame.""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +import keyword |
| 8 | +import re |
| 9 | + |
| 10 | +import polars as pl |
| 11 | + |
# Mapping from Polars dtype *classes* to the names of the corresponding
# dataframely column classes (rendered as ``dy.<name>``). Lookup is by
# ``type(dtype)``, so parameterized dtypes (Datetime, Duration, Decimal,
# Enum, List, Array, Struct) map to their base class name here; their
# parameters are reconstructed separately by ``_get_dtype_args``.
# Dtypes missing from this map fall back to ``dy.Any``.
_POLARS_DTYPE_MAP: dict[type[pl.DataType], str] = {
    pl.Boolean: "Bool",
    pl.Int8: "Int8",
    pl.Int16: "Int16",
    pl.Int32: "Int32",
    pl.Int64: "Int64",
    pl.UInt8: "UInt8",
    pl.UInt16: "UInt16",
    pl.UInt32: "UInt32",
    pl.UInt64: "UInt64",
    pl.Float32: "Float32",
    pl.Float64: "Float64",
    pl.String: "String",
    pl.Binary: "Binary",
    pl.Date: "Date",
    pl.Time: "Time",
    pl.Object: "Object",
    pl.Categorical: "Categorical",
    pl.Duration: "Duration",
    pl.Datetime: "Datetime",
    pl.Decimal: "Decimal",
    pl.Enum: "Enum",
    pl.List: "List",
    pl.Array: "Array",
    pl.Struct: "Struct",
}
| 38 | + |
| 39 | + |
def infer_schema(
    df: pl.DataFrame,
    schema_name: str = "Schema",
) -> str:
    """Derive dataframely schema source code from a Polars DataFrame.

    The DataFrame's columns and dtypes are translated into the definition of
    a ``dy.Schema`` subclass, which is returned as a string of Python code.

    Args:
        df: The DataFrame whose schema should be translated.
        schema_name: The class name to use for the generated schema.

    Returns:
        The source code of the generated schema class.

    Example:
        >>> import polars as pl
        >>> from dataframely.experimental import infer_schema
        >>> df = pl.DataFrame({
        ...     "name": ["Alice", "Bob"],
        ...     "age": [25, 30],
        ...     "score": [95.5, None],
        ... })
        >>> print(infer_schema(df, "PersonSchema"))
        class PersonSchema(dy.Schema):
            name = dy.String()
            age = dy.Int64()
            score = dy.Float64(nullable=True)

    Attention:
        This functionality is considered unstable. It may be changed at any time
        without it being considered a breaking change.

    Raises:
        ValueError: If ``schema_name`` is not a valid Python identifier.
    """
    # Reject names that cannot appear after ``class`` in generated code.
    if not schema_name.isidentifier():
        raise ValueError(
            f"schema_name must be a valid Python identifier, got {schema_name!r}"
        )
    return _generate_schema_code(df, schema_name)
| 82 | + |
| 83 | + |
def _generate_schema_code(df: pl.DataFrame, schema_name: str) -> str:
    """Render the ``dy.Schema`` class definition for the columns of ``df``."""
    used: set[str] = set()
    body: list[str] = []

    for index, (name, column) in enumerate(df.to_dict().items()):
        identifier = _make_valid_identifier(name, index)
        # Distinct columns may sanitize to the same identifier; make sure
        # every attribute name is unique by appending a numeric suffix.
        if identifier in used:
            # A trailing "_" would double up with the suffix's "_"; drop it.
            base = identifier[:-1] if identifier.endswith("_") else identifier
            suffix = 1
            while f"{base}_{suffix}" in used:
                suffix += 1
            identifier = f"{base}_{suffix}"
        used.add(identifier)
        # Only emit an alias when the attribute name had to deviate from
        # the original column name.
        alias = None if identifier == name else name
        body.append(f"    {identifier} = {_dtype_to_column_code(column, alias=alias)}")

    return "\n".join([f"class {schema_name}(dy.Schema):", *body])
| 106 | + |
| 107 | + |
| 108 | +def _make_valid_identifier(name: str, col_index: int) -> str: |
| 109 | + """Convert a string to a valid Python identifier.""" |
| 110 | + # Replace invalid characters with underscores |
| 111 | + valid_identifier = re.sub(r"[^a-zA-Z0-9_]", "_", name) |
| 112 | + |
| 113 | + # Handle empty name or name with only special characters ones with simple "_" |
| 114 | + if set(valid_identifier).issubset({"_"}): |
| 115 | + return f"column_{col_index}" |
| 116 | + # Ensure it doesn't start with a digit |
| 117 | + if valid_identifier[0].isdigit(): |
| 118 | + return "_" + valid_identifier |
| 119 | + if keyword.iskeyword(valid_identifier): |
| 120 | + return valid_identifier + "_" |
| 121 | + return valid_identifier |
| 122 | + |
| 123 | + |
def _get_dtype_args(dtype: pl.DataType, series: pl.Series) -> list[str]:
    """Render constructor arguments for parameterized Polars dtypes.

    Only non-default parameters are emitted so the generated code stays
    minimal; dtypes without parameters yield an empty list.
    """
    if isinstance(dtype, pl.Datetime):
        extra: list[str] = []
        if dtype.time_zone is not None:
            extra.append(f'time_zone="{dtype.time_zone}"')
        # "us" is the Polars default time unit and need not be spelled out.
        if dtype.time_unit != "us":
            extra.append(f'time_unit="{dtype.time_unit}"')
        return extra

    if isinstance(dtype, pl.Duration):
        # Again, "us" is the default time unit.
        return [] if dtype.time_unit == "us" else [f'time_unit="{dtype.time_unit}"']

    if isinstance(dtype, pl.Decimal):
        extra = []
        if dtype.precision is not None:
            extra.append(f"precision={dtype.precision}")
        if dtype.scale != 0:
            extra.append(f"scale={dtype.scale}")
        return extra

    if isinstance(dtype, pl.Enum):
        return [repr(dtype.categories.to_list())]

    if isinstance(dtype, pl.List):
        # Exploding exposes the inner values, from which the inner column
        # code (including its nullability) is derived.
        return [_dtype_to_column_code(series.explode())]

    if isinstance(dtype, pl.Array):
        return [_dtype_to_column_code(series.explode()), f"shape={dtype.size}"]

    if isinstance(dtype, pl.Struct):
        rendered = (
            f'"{field.name}": {_dtype_to_column_code(series.struct.field(field.name))}'
            for field in dtype.fields
        )
        return ["{" + ", ".join(rendered) + "}"]

    # Non-parameterized dtype: nothing to add.
    return []
| 163 | + |
| 164 | + |
| 165 | +def _format_args(*args: str, nullable: bool = False, alias: str | None = None) -> str: |
| 166 | + """Format arguments for column constructor.""" |
| 167 | + all_args = list(args) |
| 168 | + if nullable: |
| 169 | + all_args.append("nullable=True") |
| 170 | + if alias is not None: |
| 171 | + all_args.append(f'alias="{alias}"') |
| 172 | + return ", ".join(all_args) |
| 173 | + |
| 174 | + |
def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str:
    """Render the dataframely column constructor call for ``series``."""
    dtype = series.dtype
    # Columns containing nulls become nullable columns in the schema.
    has_nulls = series.null_count() > 0
    column_cls = _POLARS_DTYPE_MAP.get(type(dtype))

    if column_cls is None:
        # Unmapped dtype: fall back to dy.Any.
        # NOTE(review): nullability is not forwarded on this path — confirm
        # that dy.Any needs no explicit nullable flag.
        return f"dy.Any({_format_args(alias=alias)})"

    arguments = _get_dtype_args(dtype, series)
    return f"dy.{column_cls}({_format_args(*arguments, nullable=has_nulls, alias=alias)})"
0 commit comments