63 changes: 63 additions & 0 deletions test/integration/Flux/runtests.jl
@@ -0,0 +1,63 @@
using Enzyme
using Flux
using Zygote
using Test
using NNlib
using StableRNGs
using Random

# generic loss function for any Flux model
generic_loss_function(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

# compute gradients using Enzyme
function compute_enzyme_gradient(model, x, ps, st)
    return Enzyme.gradient(
        Enzyme.set_runtime_activity(Reverse),
        generic_loss_function,
        Const(model),
        x,
        ps,
        Const(st),
    )[2:3]
end

# compute gradients using Zygote
function compute_zygote_gradient(model, x, ps, st)
    _, dx, dps, _ = Zygote.gradient(generic_loss_function, model, x, ps, st)
    return dx, dps
end

# compare Enzyme gradients with Zygote gradients
function test_enzyme_gradients(model, x, ps, st)
    dx, dps = compute_enzyme_gradient(model, x, ps, st)
    dx_zygote, dps_zygote = compute_zygote_gradient(model, x, ps, st)
Member:
@gamila-wisam with Zygote broken on Julia 1.12, can you have this test compare against something other than Zygote? [Otherwise we can't compare on 1.12+.]

Author:
Sure, what about finite-difference gradients?

Member:
Sure [as long as the models aren't so large that the runtime would be unreasonable].

Author:
Yes, I've considered that; I'll try to ensure that the runtime stays reasonable.
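
A minimal sketch of what the proposed finite-difference reference could look like, assuming FiniteDifferences.jl were added to the test environment (it is not part of this PR; the function name and the `central_fdm(5, 1)` scheme are illustrative choices, not the author's implementation):

```julia
# Hypothetical finite-difference reference gradient for comparison against
# Enzyme. Requires FiniteDifferences.jl, which this PR does not yet depend on.
using FiniteDifferences

function compute_fd_gradient(model, x, ps, st)
    # 5-point central scheme for the first derivative; order chosen to
    # balance accuracy against runtime on small models.
    fdm = central_fdm(5, 1)
    dx, dps = FiniteDifferences.grad(
        fdm,
        (x, ps) -> generic_loss_function(model, x, ps, st),
        x, ps,
    )
    return dx, dps
end
```

Finite differences scale with the number of parameters, so this only stays fast if the models in the test list remain small.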


    @test check_approx(dx, dx_zygote; atol = 1.0f-3, rtol = 1.0f-3)
    @test check_approx(dps, dps_zygote; atol = 1.0f-3, rtol = 1.0f-3)
end

# small list of models to test
const MODELS_LIST = [
    # simple Dense layer
    (Dense(2, 3), randn(Float32, 2, 4)),

    # small Chain
    (Chain(Dense(2, 4, relu), Dense(4, 2)), randn(Float32, 2, 3)),
Member:
cc @CarloLucibello were there more models you wanted to test here? I know the Flux ones have a bigger list, IIRC.


    # simple Conv layer
    (Conv((3, 3), 2 => 2), randn(Float32, 5, 5, 2, 1)),
]


@testset "Enzyme Flux Integration" begin
    for (i, (model, x)) in enumerate(MODELS_LIST)
        @testset "[$i] $(nameof(typeof(model)))" begin
            # set up parameters and state
            ps = Flux.trainable(model)
            st = nothing

            # run the gradient test
            test_enzyme_gradients(model, x, ps, st)
Collaborator (on lines +56 to +60):
This approach won't work. Flux is not Lux, where the parameters are kept separate from the model; the gradient should be taken with respect to the model itself.
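
A hedged sketch of the model-level gradient the reviewer is describing (illustrative only; the names and the exact comparison are assumptions, not code from this PR):

```julia
# Illustrative sketch: differentiate with respect to the Flux model struct
# itself, rather than a separate (ps, st) pair as in the Lux convention.
using Enzyme, Flux, Zygote

model = Dense(2 => 3)
x = randn(Float32, 2, 4)
loss(m, x) = sum(abs2, m(x))

# Enzyme returns a shadow structure mirroring the model's fields;
# the input x is held constant.
dmodel_enz = Enzyme.gradient(
    Enzyme.set_runtime_activity(Reverse), loss, model, Const(x),
)[1]

# Zygote's model gradient has the same nested layout, so the two can be
# compared field by field.
dmodel_zy = Zygote.gradient(loss, model, x)[1]
```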

        end
    end
end