Add basic Flux integration tests using Enzyme#2968
Add basic Flux integration tests using Enzyme#2968gamila-wisam wants to merge 2 commits intoEnzymeAD:mainfrom
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/test/integration/Flux/runtests.jl b/test/integration/Flux/runtests.jl
index 6158379d..fab13466 100644
--- a/test/integration/Flux/runtests.jl
+++ b/test/integration/Flux/runtests.jl
@@ -33,7 +33,7 @@ function test_enzyme_gradients(model, x, ps, st)
dx_zygote, dps_zygote = compute_zygote_gradient(model, x, ps, st)
@test check_approx(dx, dx_zygote; atol = 1.0f-3, rtol = 1.0f-3)
- @test check_approx(dps, dps_zygote; atol = 1.0f-3, rtol = 1.0f-3)
+ return @test check_approx(dps, dps_zygote; atol = 1.0f-3, rtol = 1.0f-3)
end
# small list of models to test |
|
The CI failure seems to be in SciML tests and is unrelated to these Flux integration tests. The new tests run successfully locally using : |
| (Dense(2, 3), randn(Float32, 2, 4)), | ||
|
|
||
| # small Chain | ||
| (Chain(Dense(2, 4, relu), Dense(4, 2)), randn(Float32, 2, 3)), |
There was a problem hiding this comment.
cc @CarloLucibello were there more models you wanted to test here, I know the flux ones have a bigger list iirc
| # compare Enzyme gradients with Zygote gradients | ||
| function test_enzyme_gradients(model, x, ps, st) | ||
| dx, dps = compute_enzyme_gradient(model, x, ps, st) | ||
| dx_zygote, dps_zygote = compute_zygote_gradient(model, x, ps, st) |
There was a problem hiding this comment.
@gamila-wisam with zygote broken on 1.12, can you have this test against something other than zygote [otherwise we can't compare on 1.12+]
There was a problem hiding this comment.
Sure, what about finite-differences gradients?
There was a problem hiding this comment.
sure [as long as the models aren't so large that the time would be reasonable]
There was a problem hiding this comment.
Yes I've considered that, I will try to ensure that runtime stays reasonable
| ps = Flux.trainable(model) | ||
| st = nothing | ||
|
|
||
| # run the gradient test | ||
| test_enzyme_gradients(model, x, ps, st) |
There was a problem hiding this comment.
This approach won't work.
Flux is not Lux where the parameters are kept separate from the model.
The gradient should be taken with respect to the model itself.
Summary
This PR adds a small set of integration tests for Flux models using Enzyme.jl, comparing
Enzyme gradients against Zygote gradients.
Details
check_approxto compare Enzyme vs Zygote gradients.Flux.trainable(replacing deprecatedFlux.params).Testing
include("test/integration/Flux/runtests.jl").Motivation
Adds coverage for Flux models, ensuring Enzyme works correctly with common Flux layers.
Related issue
References FluxML/Flux.jl#2644