Skip to content

Commit b388066

Browse files
committed
Filter out rules that require config
1 parent 5468d28 commit b388066

2 files changed

Lines changed: 35 additions & 3 deletions

File tree

src/ruleset_loading.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,22 @@ clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind))
5252
_rule_list(frule | rrule)
5353
5454
Returns a list of all the methods of the currently defined rules of the given kind.
55-
Excluding the fallback rule that returns `nothing` for every input.
55+
Excluding the fallback rule that returns `nothing` for every input;
56+
and excluding rules that require a particular `RuleConfig`.
5657
"""
57-
function _rule_list end
58-
_rule_list(rule_kind) = (m for m in methods(rule_kind) if !_is_fallback(rule_kind, m))
58+
function _rule_list(rule_kind)
59+
return Iterators.filter(methods(rule_kind)) do m
60+
!_is_fallback(rule_kind, m) && !_requires_config(m)
61+
end
62+
end
5963

6064
"check if this is the fallback-frule/rrule that always returns `nothing`"
6165
_is_fallback(::typeof(rrule), m::Method) = m.sig === Tuple{typeof(rrule),Any,Vararg{Any}}
6266
_is_fallback(::typeof(frule), m::Method) = m.sig === Tuple{typeof(frule),Any,Any,Vararg{Any}}
6367

68+
"check if this rule requires a particular configuation (`RuleConfig`)"
69+
_requires_config(m::Method) = m.sig.parameters[2] <: RuleConfig
70+
6471
const LAST_REFRESH_RRULE = Ref(0)
6572
const LAST_REFRESH_FRULE = Ref(0)
6673
last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE

test/ruleset_loading.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,29 @@
7575
@test _is_fallback(rrule, first(methods(rrule, (Nothing,))))
7676
@test _is_fallback(frule, first(methods(frule, (Tuple{}, Nothing,))))
7777
end
78+
79+
@test "_rule_list" begin
80+
_rule_list = ChainRulesOverloadGeneration._rule_list
81+
@testset "should not have frules that need RuleConfig" begin
82+
old_frule_list = collect(_rule_list(frule))
83+
function ChainRulesCore.frule(
84+
::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, dargs, sum, f, xs
85+
)
86+
return 1.0, 1.0 # this will not be call so return doesn't matter
87+
end
88+
# New rule should not have appeared
89+
@test collect(_rule_list(frule)) == old_frule_list
90+
end
91+
92+
@testset "should not have rrules that need RuleConfig" begin
93+
old_rrule_list = collect(_rule_list(rrule))
94+
function ChainRulesCore.rrule(
95+
::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f, xs
96+
)
97+
return 1.0, x->(x,x,x) # this will not be call so return doesn't matter
98+
end
99+
# New rule should not have appeared
100+
@test collect(_rule_list(rrule)) == old_rrule_list
101+
end
102+
end
78103
end

0 commit comments

Comments
 (0)