66import sys
77from collections import defaultdict
88from collections .abc import Callable
9- from typing import Any
9+ from typing import Any , Literal
1010
1111import polars as pl
1212
@@ -99,7 +99,9 @@ class RuleFactory:
9999 """Factory class for rules created within schemas."""
100100
def __init__(
    self,
    validation_fn: Callable[[Any], pl.Expr],
    group_columns: list[str] | Literal["primary_key"] | None,
) -> None:
    """Store the validation function and the grouping specification.

    Args:
        validation_fn: Callable that is invoked with the schema and returns the
            polars expression to evaluate for the rule.
        group_columns: Columns to group by, the literal ``"primary_key"`` to
            use the schema's primary key once the rule is made, or ``None`` for
            a rule without grouping.
    """
    self.validation_fn = validation_fn
    # Kept unresolved on purpose: `"primary_key"` is only translated into
    # concrete column names when a schema is available.
    self.group_columns = group_columns
@@ -116,16 +118,28 @@ def from_rule(cls, rule: Rule) -> Self:
116118
def make(self, schema: Any) -> Rule:
    """Create a new rule from this factory.

    Args:
        schema: The schema the rule is created for. It is passed to the
            validation function and, when the factory groups by
            ``"primary_key"``, queried via ``schema.primary_key()``.

    Returns:
        A ``GroupRule`` if grouping columns are set, otherwise a plain ``Rule``.

    Raises:
        ImplementationError: If the factory groups by ``"primary_key"`` but
            ``schema.primary_key()`` returns an empty value.
    """
    group_columns: list[str] | None
    if self.group_columns == "primary_key":
        # NOTE(review): presumably imported locally to avoid a circular import
        # at module load time -- confirm before hoisting to the top of the file.
        from dataframely.exc import ImplementationError

        group_columns = schema.primary_key()
        if not group_columns:
            # Fail loudly instead of silently creating an ungrouped rule.
            raise ImplementationError(
                "Rule uses `group_by='primary_key'` but the schema has no"
                " primary key."
            )
    else:
        group_columns = self.group_columns
    if group_columns is not None:
        return GroupRule(
            # The lambda defers evaluating the validation function until the
            # rule's expression is actually requested.
            expr=lambda: self.validation_fn(schema),
            group_columns=group_columns,
        )
    return Rule(expr=lambda: self.validation_fn(schema))
125139
126140
127141def rule (
128- * , group_by : list [str ] | None = None
142+ * , group_by : list [str ] | Literal [ "primary_key" ] | None = None
129143) -> Callable [[ValidationFunction ], RuleFactory ]:
130144 """Mark a function as a rule to evaluate during validation.
131145
@@ -147,7 +161,10 @@ def rule(
147161 group_by: An optional list of columns to group by for rules operating on groups
148162 of rows. If this list is provided, the returned expression must return a
149163 single boolean value, i.e. some kind of aggregation function must be used
150- (e.g. `sum`, `any`, ...).
164+ (e.g. `sum`, `any`, ...). Pass ``"primary_key"`` to dynamically resolve to
165+ the schema's primary key columns at class creation time. This is useful for
166+ defining rules in mixin classes where the primary key is not known at
167+ definition time.
151168
152169 Note:
153170 You'll need to explicitly handle `null` values in your columns when defining
0 commit comments