Skip to content

Commit 00d12d1

Browse files
committed
Basic files created
1 parent 6559e32 commit 00d12d1

10 files changed

Lines changed: 58 additions & 21 deletions

File tree

.github/workflows/IntegrationTest.yml

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,7 @@ jobs:
1515
julia-version: [1.5]
1616
os: [ubuntu-latest]
1717
package:
18-
- {user: JuliaDiff, repo: ChainRules.jl}
19-
- {user: JuliaMath, repo: SpecialFunctions.jl}
20-
- {user: invenia, repo: BlockDiagonals.jl}
21-
- {user: invenia, repo: PDMatsExtras.jl}
22-
- {user: chrisbrahms, repo: Hankel.jl}
23-
- {user: SciML, repo: DiffEqBase.jl}
24-
- {user: dfdx, repo: Yota.jl}
18+
# - {user: Invenia, repo: Nabla.jl}
2519

2620
steps:
2721
- uses: actions/checkout@v2
@@ -43,7 +37,7 @@ jobs:
4337
# force it to use this PR's version of the package
4438
Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
4539
Pkg.update()
46-
Pkg.test() # resolver may fail with test time deps
40+
Pkg.test() # resolver may fail with test time deps
4741
catch err
4842
err isa Pkg.Resolve.ResolverError || rethrow()
4943
# If we can't resolve that means this is incompatible by SemVer and this is fine

Project.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
name = "ChainRulesOverloadGeneration"
2+
uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
3+
version = "0.1.0"
4+
5+
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7+
8+
[compat]
9+
ChainRulesCore = "0.9"
10+
11+
[extras]
12+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
13+
14+
[targets]
15+
test = ["Test"]

README.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,5 @@
1212
[![](https://img.shields.io/badge/docs-master-blue.svg)](https://juliadiff.org/ChainRulesOverloadGeneration.jl/dev)
1313
[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliadiff.org/ChainRulesOverloadGeneration.jl/stable)
1414

15-
The ChainRulesOverloadGeneration package provides a light-weight dependency for defining sensitivities for functions in your packages, without you needing to depend on ChainRules itself.
16-
17-
This will allow your package to be used with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl), which aims to provide a variety of common utilities that can be used by downstream automatic differentiation (AD) tools to define and execute forward-, reverse-, and mixed-mode primitives.
18-
19-
This package is a work in progress; PRs welcome!
15+
The ChainRulesOverloadGeneration package provides a suite of methods for using [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) rules in operator overloaded based AD systems.
16+
It tracks what rules are defined at any point in time, and lets you trigger functions to which can use `@eval` in order to define the matching operator overloads.

docs/Manifest.toml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,21 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
1111

1212
[[ChainRulesCore]]
1313
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
14-
path = ".."
14+
git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb"
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1616
version = "0.9.44"
1717

18+
[[ChainRulesOverloadGeneration]]
19+
deps = ["ChainRulesCore"]
20+
path = ".."
21+
uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
22+
version = "0.1.0"
23+
1824
[[Compat]]
1925
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
20-
git-tree-sha1 = "0900bc19193b8e672d9cd477e6cd92d9e7c02f99"
26+
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
2127
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
22-
version = "3.29.0"
28+
version = "3.30.0"
2329

2430
[[Dates]]
2531
deps = ["Printf"]

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3+
ChainRulesOverloadGeneration = "f51149dc-2911-5acf-81fc-2076a2a81d4f"
34
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
45
DocThemeIndigo = "8bac0ac5-51bf-41f9-885e-2bf1ac2bec5f"
56
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using DocThemeIndigo
44
using Markdown
55

66
DocMeta.setdocmeta!(
7-
ChainRulesCore,
7+
ChainRulesOverloadGeneration,
88
:DocTestSetup,
99
quote
1010
using Random
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using ChainRulesCore
2+
3+
export on_new_rule, refresh_rules
4+
5+
include("precompile.jl")
6+
include("ruleset_loading.jl")
7+
8+
function __init__()
9+
# Need to refresh rules when a package is loaded
10+
push!(Base.package_callbacks, _package_hook)
11+
end
12+
13+
end

src/ruleset_loading.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
# Infastructure to support generating overloads from rules.
22
_package_hook(::Base.PkgId) = refresh_rules()
3-
function __init__()
4-
# Need to refresh rules when a package is loaded
5-
push!(Base.package_callbacks, _package_hook)
6-
end
73

84
# Holds all the hook functions that are invokes when a new rule is defined
95
const RRULE_DEFINITION_HOOKS = Function[]

test/demos/forwarddiffzero.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"The simplest viable forward mode a AD, only supports `Float64`"
22
module ForwardDiffZero
33
using ChainRulesCore
4+
using ChainRulesOverloadGeneration
45
using Test
56

67
#########################################

test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using Base.Broadcast: broadcastable
2+
using BenchmarkTools
3+
using ChainRulesCore
4+
using ChainRulesOverloadGeneration
5+
using Test
6+
7+
@testset "ChainRulesCore" begin
8+
include("ruleset_loading.jl")
9+
10+
@testset "demos" begin
11+
include("demos/forwarddiffzero.jl")
12+
include("demos/reversediffzero.jl")
13+
end
14+
end

0 commit comments

Comments
 (0)