Skip to content

Commit 874da9f

Browse files
committed
make tests pass
1 parent 08f7bf7 commit 874da9f

5 files changed

Lines changed: 24 additions & 8 deletions

File tree

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
module ChainRulesOverloadGeneration
2+
13
using ChainRulesCore
24

35
export on_new_rule, refresh_rules
46

5-
include("precompile.jl")
67
include("ruleset_loading.jl")
8+
include("precompile.jl")
79

810
function __init__()
911
# Need to refresh rules when a package is loaded
1012
push!(Base.package_callbacks, _package_hook)
1113
end
1214

13-
end
15+
end # module

src/ruleset_loading.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ Returns a list of all the methods of the currently defined rules of the given ki
5555
Excluding the fallback rule that returns `nothing` for every input.
5656
"""
5757
function _rule_list end
58-
# The fallback rules are the only rules defined in ChainRulesCore & that is how we skip them
59-
# TODO this needs to be changed to work now it is in it's own repo
60-
_rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__)
58+
_rule_list(rule_kind) = (m for m in methods(rule_kind) if !_is_fallback(rule_kind, m))
6159

60+
"check if this is the fallback-frule/rrule that always returns `nothing`"
61+
_is_fallback(rule_kind, m::Method) = m.sig === Tuple{typeof(rule_kind), Any, Vararg{Any}}
6262

6363
const LAST_REFRESH_RRULE = Ref(0)
6464
const LAST_REFRESH_FRULE = Ref(0)

test/demos/forwarddiffzero.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
module ForwardDiffZero
33
using ChainRulesCore
44
using ChainRulesOverloadGeneration
5+
# resolve conflicts while this code exists in both.
6+
const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
7+
const refresh_rules = ChainRulesOverloadGeneration.refresh_rules
8+
59
using Test
610

711
#########################################

test/demos/reversediffzero.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
module ReverseDiffZero
33
using ChainRulesCore
44
using ChainRulesOverloadGeneration
5+
# resolve conflicts while this code exists in both.
6+
const on_new_rule = ChainRulesOverloadGeneration.on_new_rule
7+
const refresh_rules = ChainRulesOverloadGeneration.refresh_rules
8+
59
using Test
610

711
#########################################

test/ruleset_loading.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
end
2323

2424
@testset "# Make sure nothing happens anymore once we clear the hooks" begin
25-
ChainRulesCore.clear_new_rule_hooks!(frule)
26-
ChainRulesCore.clear_new_rule_hooks!(rrule)
25+
ChainRulesOverloadGeneration.clear_new_rule_hooks!(frule)
26+
ChainRulesOverloadGeneration.clear_new_rule_hooks!(rrule)
2727

2828
old_frule_history = copy(frule_history)
2929
old_rrule_history = copy(rrule_history)
@@ -34,9 +34,9 @@
3434
@test old_rrule_history == rrule_history
3535
@test old_frule_history == frule_history
3636
end
37-
3837
end
3938

39+
4040
@testset "_primal_sig" begin
4141
_primal_sig = ChainRulesOverloadGeneration._primal_sig
4242
@testset "frule" begin
@@ -69,4 +69,10 @@
6969
)
7070
end
7171
end
72+
73+
@testset "_is_fallback" begin
74+
_is_fallback = ChainRulesOverloadGeneration._is_fallback
75+
@test _is_fallback(rrule, only(methods(rrule, (Nothing,))))
76+
@test _is_fallback(frule, only(methods(frule, (Nothing,))))
77+
end
7278
end

0 commit comments

Comments
 (0)