The following came up in SpeedyWeather/SpeedyWeather.jl#1051.
Here I have condensed it into an MWE with some help from Claude. In general, there seems to be an issue with differentiating through kernels that depends on how they are configured with KernelAbstractions (KA).
using KernelAbstractions
using Enzyme
# Element-wise scaling kernel: the work-item with linear global index `i`
# writes `out[i] = x[i] * c`.
# `inbounds = true` disables bounds checking in the body, so the launch's ndrange
# must stay within the lengths of `out` and `x`.
@kernel inbounds = true function scale_kernel!(out, x, c)
    i = @index(Global, Linear)
    xi = x[i]
    out[i] = xi * c
end
# Problem size plus the primal/shadow buffers shared by both cases below.
# `dout` is the shadow of `out` (the reverse-mode seed); `dx` is the shadow of `x`,
# into which the gradient is accumulated.
n = 16
x = rand(n)       # Float64 elements (rand's default eltype)
out = zeros(n)
dout = ones(n)
dx = zero(x)
# ── Case A: 2-arg kernel config (workgroup size only); WORKING ──────
"""
    scale_dynamic!(out, x; c = 2.0)

Write `x[i] * c` into `out[i]` for `i in 1:length(out)` by launching `scale_kernel!`
on the KernelAbstractions CPU backend, then return `out`.

The kernel is configured with only a workgroup size (64). The ndrange is supplied
at launch time as `length(out)`.
"""
function scale_dynamic!(out, x; c = 2.0)
    cpu = KernelAbstractions.CPU()
    kern = scale_kernel!(cpu, 64)               # 2-arg: workgroup size only
    kern(out, x, c; ndrange = length(out))      # dynamic ndrange, given per launch
    return out
end
# Reverse-mode differentiation of Case A: (out, dout) and (x, dx) are primal/shadow
# pairs. Errors are caught on purpose so that Case B still runs afterwards. The
# backtrace is printed along with the message so the failing frame is visible.
try
    autodiff(Reverse, scale_dynamic!, Const, Duplicated(out, dout), Duplicated(x, dx))
    println(" OK dx[1] = ", dx[1])
catch e
    println(" ERROR: ", sprint(showerror, e, catch_backtrace()))
end
# ── Case B: 3-arg kernel config (workgroup + static worksize); NON-WORKING ──
# Matches SpeedyWeather's configure_kernel: loop = kernel!(device(arch), workgroup, worksize)
# Reset all buffers before Case B. Case A accumulated into `dx`, and its reverse
# pass may also have consumed (zeroed) the seed in `dout`.
out .= 0
dout .= 1
dx .= 0
"""
    scale_static!(out, x; c = 2.0)

Perform the same element-wise scaling as `scale_dynamic!` (write `x[i] * c` into
`out[i]`) and return `out`.

Here `scale_kernel!` is configured with both a workgroup size (16) and a static
worksize `(length(out),)`, so the launch call passes no ndrange.
"""
function scale_static!(out, x; c = 2.0)
    cpu = KernelAbstractions.CPU()
    # 3-arg configuration: workgroup size and worksize are fixed when the kernel is built.
    kern = scale_kernel!(cpu, 16, (length(out),))
    kern(out, x, c)   # ndrange comes from the configured worksize
    return out
end
# Reverse-mode differentiation of Case B (static worksize): same primal/shadow pairs
# as Case A. The error is caught so the script reports rather than aborts. The
# backtrace is printed along with the message so the failing frame is visible.
try
    autodiff(Reverse, scale_static!, Const, Duplicated(out, dout), Duplicated(x, dx))
    println(" OK dx[1] = ", dx[1])
catch e
    println(" ERROR: ", sprint(showerror, e, catch_backtrace()))
end
The following came up in SpeedyWeather/SpeedyWeather.jl#1051.
Here I have condensed it into an MWE with some help from Claude. In general, there seems to be an issue with differentiating through kernels that depends on how they are configured with KernelAbstractions (KA).