Skip to content

Commit 2534bf7

Browse files
authored
feat: Allow to automatically run group rules on primary key (#300)
1 parent 1527718 commit 2534bf7

File tree

2 files changed

+83
-6
lines changed

2 files changed

+83
-6
lines changed

dataframely/_rule.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import sys
77
from collections import defaultdict
88
from collections.abc import Callable
9-
from typing import Any
9+
from typing import Any, Literal
1010

1111
import polars as pl
1212

@@ -99,7 +99,9 @@ class RuleFactory:
9999
"""Factory class for rules created within schemas."""
100100

101101
def __init__(
102-
self, validation_fn: Callable[[Any], pl.Expr], group_columns: list[str] | None
102+
self,
103+
validation_fn: Callable[[Any], pl.Expr],
104+
group_columns: list[str] | Literal["primary_key"] | None,
103105
) -> None:
104106
self.validation_fn = validation_fn
105107
self.group_columns = group_columns
@@ -116,16 +118,28 @@ def from_rule(cls, rule: Rule) -> Self:
116118

117119
def make(self, schema: Any) -> Rule:
118120
"""Create a new rule from this factory."""
119-
if self.group_columns is not None:
121+
group_columns: list[str] | None
122+
if self.group_columns == "primary_key":
123+
from dataframely.exc import ImplementationError
124+
125+
group_columns = schema.primary_key()
126+
if not group_columns:
127+
raise ImplementationError(
128+
"Rule uses `group_by='primary_key'` but the schema has no"
129+
" primary key."
130+
)
131+
else:
132+
group_columns = self.group_columns
133+
if group_columns is not None:
120134
return GroupRule(
121135
expr=lambda: self.validation_fn(schema),
122-
group_columns=self.group_columns,
136+
group_columns=group_columns,
123137
)
124138
return Rule(expr=lambda: self.validation_fn(schema))
125139

126140

127141
def rule(
128-
*, group_by: list[str] | None = None
142+
*, group_by: list[str] | Literal["primary_key"] | None = None
129143
) -> Callable[[ValidationFunction], RuleFactory]:
130144
"""Mark a function as a rule to evaluate during validation.
131145
@@ -147,7 +161,10 @@ def rule(
147161
group_by: An optional list of columns to group by for rules operating on groups
148162
of rows. If this list is provided, the returned expression must return a
149163
single boolean value, i.e. some kind of aggregation function must be used
150-
(e.g. `sum`, `any`, ...).
164+
(e.g. `sum`, `any`, ...). Pass `"primary_key"` to dynamically resolve to
165+
the schema's primary key columns at class creation time. This is useful for
166+
defining rules in mixin classes where the primary key is not known at
167+
definition time. If the schema has no primary key, an `ImplementationError` is raised.
151168
152169
Note:
153170
You'll need to explicitly handle `null` values in your columns when defining

tests/schema/test_rule_implementation.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,66 @@ def test_group_rule_group_by_error() -> None:
2929
)
3030

3131

32+
def test_group_rule_primary_key_single() -> None:
33+
class MySchema(dy.Schema):
34+
a = dy.Int64(primary_key=True)
35+
b = dy.Int64()
36+
37+
@dy.rule(group_by="primary_key")
38+
def b_positive(cls) -> pl.Expr:
39+
return (pl.col("b") > 0).all()
40+
41+
rules = MySchema._schema_validation_rules()
42+
assert isinstance(rules["b_positive"], GroupRule)
43+
assert rules["b_positive"].group_columns == ["a"]
44+
45+
46+
def test_group_rule_primary_key_composite() -> None:
47+
class MySchema(dy.Schema):
48+
a = dy.Int64(primary_key=True)
49+
b = dy.Int64(primary_key=True)
50+
c = dy.Int64()
51+
52+
@dy.rule(group_by="primary_key")
53+
def c_positive(cls) -> pl.Expr:
54+
return (pl.col("c") > 0).all()
55+
56+
rules = MySchema._schema_validation_rules()
57+
assert isinstance(rules["c_positive"], GroupRule)
58+
assert sorted(rules["c_positive"].group_columns) == ["a", "b"]
59+
60+
61+
def test_group_rule_primary_key_no_pk() -> None:
62+
with pytest.raises(
63+
ImplementationError,
64+
match=r"group_by='primary_key'.*no primary key",
65+
):
66+
67+
class MySchema(dy.Schema):
68+
a = dy.Int64()
69+
70+
@dy.rule(group_by="primary_key")
71+
def a_positive(cls) -> pl.Expr:
72+
return (pl.col("a") > 0).all()
73+
74+
75+
def test_group_rule_primary_key_mixin() -> None:
76+
class MyMixin:
77+
id = dy.Int64(primary_key=True)
78+
value = dy.Int64()
79+
80+
@dy.rule(group_by="primary_key")
81+
def value_positive(cls) -> pl.Expr:
82+
return (pl.col("value") > 0).all()
83+
84+
class MySchema(MyMixin, dy.Schema):
85+
other_id = dy.Int64(primary_key=True)
86+
87+
rules = MySchema._schema_validation_rules()
88+
assert isinstance(rules["value_positive"], GroupRule)
89+
assert rules["value_positive"].group_columns == ["id", "other_id"]
90+
91+
3292
def test_rule_column_overlap_error() -> None:
3393
with pytest.raises(
3494
ImplementationError,

0 commit comments

Comments
 (0)