Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .JuliaFormatter.toml

This file was deleted.

19 changes: 19 additions & 0 deletions .github/workflows/FormatCheck.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: format-check

on:
push:
branches:
- 'master'
- 'main'
- 'release-'
tags: '*'
pull_request:

jobs:
runic:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: fredrikekre/runic-action@v1
with:
version: '1'
16 changes: 8 additions & 8 deletions ext/ADTypesChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@ module ADTypesChainRulesCoreExt

using ADTypes: ADTypes, AutoChainRules
using ChainRulesCore: HasForwardsMode, HasReverseMode,
NoForwardsMode, NoReverseMode,
RuleConfig
NoForwardsMode, NoReverseMode,
RuleConfig

# see https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/superpowers/ruleconfig.html

function ADTypes.mode(::AutoChainRules{RC}) where {
RC <: RuleConfig{>:HasForwardsMode}
}
RC <: RuleConfig{>:HasForwardsMode},
}
return ADTypes.ForwardMode()
end

function ADTypes.mode(::AutoChainRules{RC}) where {
RC <: RuleConfig{>:HasReverseMode}
}
RC <: RuleConfig{>:HasReverseMode},
}
return ADTypes.ReverseMode()
end

function ADTypes.mode(::AutoChainRules{RC}) where {
RC <: RuleConfig{>:Union{HasForwardsMode, HasReverseMode}}
}
RC <: RuleConfig{>:Union{HasForwardsMode, HasReverseMode}},
}
# more specific than the previous two
return ADTypes.ForwardOrReverseMode()
end
Expand Down
40 changes: 20 additions & 20 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,26 @@ include("symbols.jl")
# Automatic Differentiation
export AbstractADType
export AutoChainRules,
AutoDiffractor,
AutoEnzyme,
AutoFastDifferentiation,
AutoFiniteDiff,
AutoFiniteDifferences,
AutoForwardDiff,
AutoGTPSA,
AutoModelingToolkit,
AutoMooncake,
AutoMooncakeForward,
AutoPolyesterForwardDiff,
AutoReverseDiff,
AutoSymbolics,
AutoTapir,
AutoTaylorDiff,
AutoTracker,
AutoZygote,
NoAutoDiff,
NoAutoDiffSelectedError,
AutoReactant
AutoDiffractor,
AutoEnzyme,
AutoFastDifferentiation,
AutoFiniteDiff,
AutoFiniteDifferences,
AutoForwardDiff,
AutoGTPSA,
AutoModelingToolkit,
AutoMooncake,
AutoMooncakeForward,
AutoPolyesterForwardDiff,
AutoReverseDiff,
AutoSymbolics,
AutoTapir,
AutoTaylorDiff,
AutoTracker,
AutoZygote,
NoAutoDiff,
NoAutoDiffSelectedError,
AutoReactant
@public AbstractMode
@public ForwardMode, ReverseMode, ForwardOrReverseMode, SymbolicMode
@public mode
Expand Down
6 changes: 3 additions & 3 deletions src/compat.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Backward compatibility with `public` keyword, as suggested in
# Backward compatibility with `public` keyword, as suggested in
# https://discourse.julialang.org/t/is-compat-jl-worth-it-for-the-public-keyword/119041/22
macro public(ex)
if VERSION >= v"1.11.0-DEV.469"
return if VERSION >= v"1.11.0-DEV.469"
args = ex isa Symbol ? (ex,) :
Base.isexpr(ex, :tuple) ? ex.args : error("Failed to mark $ex as public")
Base.isexpr(ex, :tuple) ? ex.args : error("Failed to mark $ex as public")
esc(Expr(:public, args...))
else
nothing
Expand Down
49 changes: 28 additions & 21 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ end
mode(::AutoChainRules) = ForwardOrReverseMode() # specialized in the extension

function Base.show(io::IO, backend::AutoChainRules)
print(io, AutoChainRules, "(ruleconfig=", repr(backend.ruleconfig; context = io), ")")
return print(io, AutoChainRules, "(ruleconfig=", repr(backend.ruleconfig; context = io), ")")
end

"""
Expand Down Expand Up @@ -68,7 +68,8 @@ struct AutoEnzyme{M, A} <: AbstractADType
end

function AutoEnzyme(;
mode::M = nothing, function_annotation::Type{A} = Nothing) where {M, A}
mode::M = nothing, function_annotation::Type{A} = Nothing
) where {M, A}
return AutoEnzyme{M, A}(mode)
end

Expand All @@ -79,7 +80,7 @@ function Base.show(io::IO, backend::AutoEnzyme{M, A}) where {M, A}
!isnothing(backend.mode) && print(io, "mode=", repr(backend.mode; context = io))
!isnothing(backend.mode) && !(A <: Nothing) && print(io, ", ")
!(A <: Nothing) && print(io, "function_annotation=", repr(A; context = io))
print(io, ")")
return print(io, ")")
end


Expand All @@ -101,12 +102,13 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
+ an [`AutoEnzyme`](@ref) object if a specific mode is required
+ `nothing` to choose the best mode automatically
"""
struct AutoReactant{M<:AutoEnzyme} <: AbstractADType
struct AutoReactant{M <: AutoEnzyme} <: AbstractADType
mode::M
end

function AutoReactant(;
mode::Union{AutoEnzyme,Nothing} = nothing)
mode::Union{AutoEnzyme, Nothing} = nothing
)
if mode === nothing
mode = AutoEnzyme()
end
Expand All @@ -118,7 +120,7 @@ mode(r::AutoReactant) = mode(r.mode)
function Base.show(io::IO, backend::AutoReactant)
print(io, AutoReactant, "(")
print(io, "mode=", repr(backend.mode; context = io))
print(io, ")")
return print(io, ")")
end

"""
Expand Down Expand Up @@ -184,7 +186,7 @@ function Base.show(io::IO, backend::AutoFiniteDiff)
print(io, "absstep=", repr(backend.absstep; context = io), ", ")
backend.dir != true &&
print(io, "dir=", repr(backend.dir; context = io))
print(io, ")")
return print(io, ")")
end

"""
Expand All @@ -209,7 +211,7 @@ end
mode(::AutoFiniteDifferences) = ForwardMode()

function Base.show(io::IO, backend::AutoFiniteDifferences)
print(io, AutoFiniteDifferences, "(fdm=", repr(backend.fdm; context = io), ")")
return print(io, AutoFiniteDifferences, "(fdm=", repr(backend.fdm; context = io), ")")
end

"""
Expand Down Expand Up @@ -245,10 +247,12 @@ mode(::AutoForwardDiff) = ForwardMode()

function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize}
print(io, AutoForwardDiff, "(")
chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io),
(backend.tag !== nothing ? ", " : ""))
chunksize !== nothing && print(
io, "chunksize=", repr(chunksize; context = io),
(backend.tag !== nothing ? ", " : "")
)
backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io))
print(io, ")")
return print(io, ")")
end

"""
Expand Down Expand Up @@ -277,7 +281,7 @@ mode(::AutoTaylorDiff) = ForwardMode()
function Base.show(io::IO, ::AutoTaylorDiff{order}) where {order}
print(io, AutoTaylorDiff, "(")
print(io, "order=", repr(order; context = io))
print(io, ")")
return print(io, ")")
end

"""
Expand Down Expand Up @@ -309,7 +313,7 @@ mode(::AutoGTPSA) = ForwardMode()
function Base.show(io::IO, backend::AutoGTPSA{D}) where {D}
print(io, AutoGTPSA, "(")
D != Nothing && print(io, "descriptor=", repr(backend.descriptor; context = io))
print(io, ")")
return print(io, ")")
end

"""
Expand Down Expand Up @@ -343,7 +347,7 @@ function Base.show(io::IO, backend::AutoMooncake)
print(io, AutoMooncake, "(")
backend.config !== nothing &&
print(io, "config=", repr(backend.config; context = io))
print(io, ")")
return print(io, ")")
end

"""
Expand Down Expand Up @@ -377,7 +381,7 @@ function Base.show(io::IO, backend::AutoMooncakeForward)
print(io, AutoMooncakeForward, "(")
backend.config !== nothing &&
print(io, "config=", repr(backend.config; context = io))
print(io, ")")
return print(io, ")")
end

"""
Expand Down Expand Up @@ -415,10 +419,12 @@ mode(::AutoPolyesterForwardDiff) = ForwardMode()

function Base.show(io::IO, backend::AutoPolyesterForwardDiff{chunksize}) where {chunksize}
print(io, AutoPolyesterForwardDiff, "(")
chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io),
(backend.tag !== nothing ? ", " : ""))
chunksize !== nothing && print(
io, "chunksize=", repr(chunksize; context = io),
(backend.tag !== nothing ? ", " : "")
)
backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io))
print(io, ")")
return print(io, ")")
end

"""
Expand Down Expand Up @@ -466,7 +472,8 @@ function Base.getproperty(ad::AutoReverseDiff, s::Symbol)
if s === :compile
Base.depwarn(
"`ad.compile` where `ad` is `AutoReverseDiff` has been deprecated and will be removed in v2. Instead it is available as a compile-time constant as `AutoReverseDiff{true}` or `AutoReverseDiff{false}`.",
:getproperty)
:getproperty
)
end
return getfield(ad, s)
end
Expand All @@ -476,7 +483,7 @@ mode(::AutoReverseDiff) = ReverseMode()
function Base.show(io::IO, ::AutoReverseDiff{compile}) where {compile}
print(io, AutoReverseDiff, "(")
compile && print(io, "compile=true")
print(io, ")")
return print(io, ")")
end

"""
Expand Down Expand Up @@ -522,7 +529,7 @@ mode(::AutoTapir) = ReverseMode()
function Base.show(io::IO, backend::AutoTapir)
print(io, AutoTapir, "(")
!(backend.safe_mode) && print(io, "safe_mode=false")
print(io, ")")
return print(io, ")")
end

"""
Expand Down
13 changes: 9 additions & 4 deletions src/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

@deprecate AutoSparseForwardDiff(; kwargs...) AutoSparse(AutoForwardDiff(; kwargs...))

@deprecate AutoSparsePolyesterForwardDiff(; kwargs...) AutoSparse(AutoPolyesterForwardDiff(;
kwargs...))
@deprecate AutoSparsePolyesterForwardDiff(; kwargs...) AutoSparse(
AutoPolyesterForwardDiff(;
kwargs...
)
)

@deprecate AutoSparseReverseDiff(; kwargs...) AutoSparse(AutoReverseDiff(; kwargs...))

Expand All @@ -26,14 +29,16 @@ end
function AutoModelingToolkit(obj_sparse::Bool, cons_sparse::Bool)
Base.depwarn(
"`AutoModelingToolkit(obj_sparse, cons_sparse)` is deprecated, use `AutoSymbolics()` or `AutoSparse(AutoSymbolics())` instead.",
:AutoModelingToolkit; force = false)
:AutoModelingToolkit; force = false
)
return mtk_to_symbolics(obj_sparse, cons_sparse)
end

function AutoModelingToolkit(; obj_sparse::Bool = false, cons_sparse::Bool = false)
Base.depwarn(
"`AutoModelingToolkit(; obj_sparse, cons_sparse)` is deprecated, use `AutoSymbolics()` or `AutoSparse(AutoSymbolics())` instead.",
:AutoModelingToolkit; force = false)
:AutoModelingToolkit; force = false
)
return mtk_to_symbolics(obj_sparse, cons_sparse)
end

Expand Down
18 changes: 10 additions & 8 deletions src/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ Wraps an ADTypes.jl object to deal with sparse Jacobians and Hessians.
)
"""
struct AutoSparse{
D <: AbstractADType,
S <: AbstractSparsityDetector,
C <: AbstractColoringAlgorithm
} <: AbstractADType
D <: AbstractADType,
S <: AbstractSparsityDetector,
C <: AbstractColoringAlgorithm,
} <: AbstractADType
dense_ad::D
sparsity_detector::S
coloring_algorithm::C
Expand All @@ -203,11 +203,12 @@ end
function AutoSparse(
dense_ad;
sparsity_detector = NoSparsityDetector(),
coloring_algorithm = NoColoringAlgorithm())
coloring_algorithm = NoColoringAlgorithm()
)
return AutoSparse{
typeof(dense_ad),
typeof(sparsity_detector),
typeof(coloring_algorithm)
typeof(coloring_algorithm),
}(dense_ad, sparsity_detector, coloring_algorithm)
end

Expand All @@ -218,9 +219,10 @@ function Base.show(io::IO, backend::AutoSparse)
end
if backend.coloring_algorithm != NoColoringAlgorithm()
print(
io, ", coloring_algorithm=", repr(backend.coloring_algorithm, context = io))
io, ", coloring_algorithm=", repr(backend.coloring_algorithm, context = io)
)
end
print(io, ")")
return print(io, ")")
end

"""
Expand Down
11 changes: 7 additions & 4 deletions src/symbols.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ ADTypes.AutoZygote()
"""
Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...)

for backend in (:ChainRules, :Diffractor, :Enzyme, :Reactant, :FastDifferentiation,
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :GTPSA, :Mooncake, :PolyesterForwardDiff,
:ReverseDiff, :Symbolics, :Tapir, :TaylorDiff, :Tracker, :Zygote)
for backend in (
:ChainRules, :Diffractor, :Enzyme, :Reactant, :FastDifferentiation,
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :GTPSA, :Mooncake, :PolyesterForwardDiff,
:ReverseDiff, :Symbolics, :Tapir, :TaylorDiff, :Tracker, :Zygote,
)
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(
args...; kws...)
args...; kws...
)
end

Auto(::Nothing) = NoAutoDiff()
Loading
Loading