diff --git a/src/Enzyme.jl b/src/Enzyme.jl index fcf75832f6..7061339431 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1314,13 +1314,21 @@ import .Compiler: remove_innerty, UnknownTapeType # If the tape is not cached, compile it if obj === nothing - Compiler.JuliaContext() do ctx + ts_ctx = Compiler.JuliaContext() + ctx = Compiler.context(ts_ctx) + Compiler.activate(ctx) + try _, meta = GPUCompiler.compile(:llvm, job) obj = meta.TapeType tape_cache[key] = obj + obj + finally + Compiler.deactivate(ctx) + Compiler.dispose(ts_ctx) end + else + obj end - obj finally unlock(tape_cache_lock) end diff --git a/src/compiler.jl b/src/compiler.jl index 4d73816804..bab631053a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1180,6 +1180,10 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV return custom, state end +const DumpPreNestedCheck = Ref(false) +const DumpPreNestedOpt = Ref(false) +const DumpPostNestedOpt = Ref(false) + function nested_codegen!( mode::API.CDerivativeMode, mod::LLVM.Module, @@ -1219,8 +1223,16 @@ function nested_codegen!( API.AddPreserveNVVMPass!(pm, true) #=Begin=# LLVM.run!(pm, otherMod) end + + if DumpPreNestedCheck[] + API.EnzymeDumpModuleRef(otherMod.ref) + end check_ir(interp, job, otherMod) + + if DumpPreNestedOpt[] + API.EnzymeDumpModuleRef(otherMod.ref) + end # Skipped inline of blas @@ -1229,6 +1241,11 @@ function nested_codegen!( # Apply first stage of optimization's so that this module is at the same stage as `mod` optimize!(otherMod, JIT.get_tm()) + + if DumpPostNestedOpt[] + API.EnzymeDumpModuleRef(otherMod.ref) + end + # 4) Link the corresponding module LLVM.link!(mod, otherMod) # 5) Call the function @@ -2228,6 +2245,7 @@ end include("rules/activityrules.jl") const DumpPreEnzyme = Ref(false) +const DumpPostEnzyme = Ref(false) const DumpPostWrap = Ref(false) function enzyme!( @@ -2576,6 +2594,9 @@ function enzyme!( adjointf = adjointf == nothing ? nothing : functions(mod)[adjointfname] augmented_primalf = augmented_primalf == nothing ? nothing : functions(mod)[augmented_primalfname] + if DumpPostEnzyme[] + API.EnzymeDumpModuleRef(mod.ref) + end return adjointf, augmented_primalf, TapeType end @@ -5755,6 +5776,7 @@ function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module return CompileResult(adjoint_ptr, primal_ptr, TapeType, edges) end +const DumpPrePostOpt = Ref(false) const DumpPostOpt = Ref(false) # actual compilation @@ -5786,6 +5808,9 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri string(mod) end if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI + if DumpPrePostOpt[] + API.EnzymeDumpModuleRef(mod.ref) + end post_optimze!(mod, JIT.get_tm()) if DumpPostOpt[] API.EnzymeDumpModuleRef(mod.ref) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 25b29cf080..6486d7375b 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -880,26 +880,44 @@ end @generated function same_sized(x::Tuple) result = :true prev = nothing + todo = Tuple{Expr, Type}[] for i in 1:length(x.parameters) - if x.parameters[i] <: Number + push!(todo, (:(x[$i]), x.parameters[i])) + end + while length(todo) != 0 + expr, ty = pop!(todo) + if ty <: Number || ty <: Base.RefValue continue end + if ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} + for i in 1:length(ty.parameters[4].parameters) + push!(todo, (:($expr.args[$i]), ty.parameters[4].parameters[i])) + end + continue + end + @assert ty <: AbstractArray if prev == nothing prev = quote - sz = size(x[$i]) + sz = size($expr) end continue end if result == :true result = quote - sz == size(x[$i]) + sz == size($expr) end else result = quote - $result && sz == size(x[$i]) + $result && sz == size($expr) end end end + if result == :true + return quote + Base.@_inline_meta + true + end + end return quote Base.@_inline_meta $prev @@ -907,29 +925,78 @@ end end end +@generated function first_array(x::Tuple) + result = :true + prev = nothing + todo = Tuple{Expr, Type}[] + for i in 1:length(x.parameters) + push!(todo, (:(x[$i]), x.parameters[i])) + end + while length(todo) != 0 + expr, ty = pop!(todo) + if ty <: Number || ty <: Base.RefValue + continue + end + if ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} + for i in 1:length(ty.parameters[4].parameters) + push!(todo, (:($expr.args[$i]), ty.parameters[4].parameters[i])) + end + continue + end + @assert ty <: AbstractArray + return quote + Base.@_inline_meta + $expr + end + end + return quote + Base.@_inline_meta + throw(AssertionError("No array")) + end +end + -Base.@propagate_inbounds overload_broadcast_getindex(A::Union{Ref,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices -Base.@propagate_inbounds overload_broadcast_getindex(::Ref{Type{T}}, I) where {T} = T +Base.@propagate_inbounds @inline overload_broadcast_getindex(A::Union{Ref,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices +Base.@propagate_inbounds @inline overload_broadcast_getindex(::Ref{Type{T}}, I) where {T} = T # Tuples are statically known to be singleton or vector-like -Base.@propagate_inbounds overload_broadcast_getindex(A::Tuple{Any}, I) = A[1] -Base.@propagate_inbounds overload_broadcast_getindex(A::Tuple, I) = error("unhandled") # A[I[1]] -Base.@propagate_inbounds overload_broadcast_getindex(A, I) = A[I] +Base.@propagate_inbounds @inline overload_broadcast_getindex(A::Tuple{Any}, I) = A[1] +Base.@propagate_inbounds @inline overload_broadcast_getindex(A::Tuple, I) = error("unhandled") # A[I[1]] +Base.@propagate_inbounds @generated function overload_broadcast_getindex(bc::Base.Broadcast.Broadcasted, I) + args = Expr[] + for i in 1:length(bc.parameters[4].parameters) + push!(args, Expr(:call, overload_broadcast_getindex, :(bc.args[$i]), :I)) + end + expr = Expr(:call, Base.Broadcast._broadcast_getindex_evalf, :(bc.f), args...) + return quote + Base.@_inline_meta + $expr + end +end + +Base.@propagate_inbounds @inline overload_broadcast_getindex(A, I) = @inbounds A[I] -@inline function override_bc_materialize(bc) +struct OverrideBCMaterialize{ElType} +end + +@inline function (::OverrideBCMaterialize{ElType})(bc) where ElType if bc.args isa Tuple{AbstractArray} && bc.f === Base.identity return copy(bc.args[1]) end - ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args) - dest = similar(bc, ElType) - if all(isa_array_or_number, bc.args) && same_sized(bc.args) - @inbounds @simd for I in 1:length(bc) - val = Base.Broadcast._broadcast_getindex_evalf(bc.f, map(Base.Fix2(overload_broadcast_getindex, I), bc.args)...) + dest = @inline similar(bc, ElType) + if same_sized(bc.args) + # dest = @inline similar(first_array(bc.args), ElType) + @inbounds @simd for I in 1:length(bc) + val = overload_broadcast_getindex(bc, I) dest[I] = val end + return dest else - Base.copyto!(dest, bc) + # The existing code is rather slow for broadcast in practice: https://github.com/EnzymeAD/Enzyme.jl/issues/1434 + src = @inline Base.Broadcast.preprocess(nothing, bc) + idx = Base.eachindex(src) + @inline Enzyme.Compiler.Interpreter.lindex_v3(idx, dest, src) + return dest end - return dest end struct MultiOp{Position, NumUsed, F1, F2} @@ -958,12 +1025,28 @@ end end end -@inline function array_or_number(@nospecialize(Ty))::Bool - return Ty <: AbstractArray || Ty <: Number +@inline function bc_or_array_or_number_ty(@nospecialize(Ty::Type))::Bool + if Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} + return all(bc_or_array_or_number_ty, Ty.parameters[4].parameters) + else + return Ty <: AbstractArray || Ty <: Number || Ty <: Base.RefValue + end end -@inline function isa_array_or_number(@nospecialize(x))::Bool - return x isa AbstractArray || x isa Number +@inline function has_array(@nospecialize(Ty::Type))::Bool + if Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} + return any(has_array, Ty.parameters[4].parameters) + else + return Ty <: AbstractArray + end +end + +@generated function isa_bc_or_array_or_number(x)::Bool + res = bc_or_array_or_number_ty(x) + return quote + Base.@_inline_meta + $res + end end @inline function num_or_eltype(@nospecialize(Ty))::Type @@ -974,6 +1057,25 @@ end end end + +## Computation of inferred result type, for empty and concretely inferred cases only +ty_broadcast_getindex_eltype(interp, bc::Type{<:Base.Broadcast.Broadcasted}) = ty_combine_eltypes(interp, bc.parameters[3], (bc.parameters[4].parameters...,)) +ty_broadcast_getindex_eltype(interp, A) = eltype(A) # Tuple, Array, etc. + +ty_eltypes(interp, ::Tuple{}) = Tuple{} +ty_eltypes(interp, t::Tuple{Any}) = Iterators.TupleOrBottom(ty_broadcast_getindex_eltype(interp, t[1])) +ty_eltypes(interp, t::Tuple{Any,Any}) = Iterators.TupleOrBottom(ty_broadcast_getindex_eltype(interp, t[1]), ty_broadcast_getindex_eltype(interp, t[2])) +ty_eltypes(interp, t::Tuple) = (TT = ty_eltypes(interp, Base.tail(t)); TT === Union{} ? Union{} : Iterators.TupleOrBottom(ty_broadcast_getindex_eltype(interp, t[1]), TT.parameters...)) +# eltypes(t::Tuple) = Iterators.TupleOrBottom(ntuple(i -> _broadcast_getindex_eltype(t[i]), Val(length(t)))...) + +# Inferred eltype of result of broadcast(f, args...) +function ty_combine_eltypes(interp, f, args::Tuple) + argT = ty_eltypes(interp, args) + argT === Union{} && return Union{} + preprom = Core.Compiler._return_type(interp, Tuple{f, argT.parameters...}) + return Base.promote_typejoin_union(preprom) +end + function abstract_call_known( interp::EnzymeInterpreter{Handler}, @nospecialize(f), @@ -1012,20 +1114,24 @@ function abstract_call_known( if interp.broadcast_rewrite if f === Base.materialize && length(argtypes) == 2 bcty = widenconst(argtypes[2]) - if Base.isconcretetype(bcty) && bcty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} && all(array_or_number, bcty.parameters[4].parameters) && any(Base.Fix2(Base.:<:, AbstractArray), bcty.parameters[4].parameters) + if Base.isconcretetype(bcty) && bcty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} && bc_or_array_or_number_ty(bcty) && has_array(bcty) + ElType = ty_broadcast_getindex_eltype(interp, bcty) + if ElType !== Union{} && Base.isconcretetype(ElType) + fn2 = Enzyme.Compiler.Interpreter.OverrideBCMaterialize{ElType}() arginfo2 = ArgInfo( fargs isa Nothing ? nothing : - [:(Enzyme.Compiler.Interpreter.override_bc_materialize), fargs[2:end]...], - [Core.Const(Enzyme.Compiler.Interpreter.override_bc_materialize), argtypes[2:end]...], + [:(fn2), fargs[2:end]...], + [Core.Const(fn2), argtypes[2:end]...], ) return Base.@invoke abstract_call_known( interp::AbstractInterpreter, - Enzyme.Compiler.Interpreter.override_bc_materialize::Any, + fn2::Any, arginfo2::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int, ) + end end end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 8be7375c03..51895a096f 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -241,8 +241,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr if isa(inst, LLVM.CallInst) push!(calls, inst) # remove illegal invariant.load and jtbaa_const invariants - elseif isa(inst, LLVM.LoadInst) - + elseif isa(inst, LLVM.LoadInst) fn_got, _ = get_base_and_offset(operands(inst)[1]; offsetAllowed=false, inttoptr=false) fname = String(name(fn_got)) match_ = match(r"^jlplt_(.*)_\d+_got$", fname) @@ -468,6 +467,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr inst = pop!(calls) check_ir!(interp, job, errors, imported, inst, calls, mod) end + return errors end diff --git a/src/errors.jl b/src/errors.jl index 08c49dc408..83fc206fde 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -209,8 +209,10 @@ function Base.showerror(io::IO, ece::EnzymeMutabilityException) print(io, msg, '\n') end -struct EnzymeRuntimeActivityError <: EnzymeError +struct EnzymeRuntimeActivityError{MT,WT} <: EnzymeError msg::Cstring + mi::MT + world::WT end function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) @@ -239,10 +241,57 @@ function Base.showerror(io::IO, ece::EnzymeRuntimeActivityError) io, " b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.", ) + if ece.mi !== nothing + print(io, " Failure within method: ", ece.mi, "\n") + printstyled(io, "Hint"; bold = true, color = :cyan) + printstyled( + io, + ": catch this exception as `err` and call `code_typed(err)` to inspect the errornous code.\nIf you have Cthulu.jl loaded you can also use `code_typed(err; interactive = true)` to interactively introspect the code.\n"; + color = :cyan, + ) + end msg = Base.unsafe_string(ece.msg) print(io, msg, '\n') end +function InteractiveUtils.code_typed(ece::EnzymeRuntimeActivityError; interactive::Bool=false, kwargs...) + mi = ece.mi + if mi === nothing + throw(AssertionError("code_typed(::EnzymeRuntimeActivityError; interactive::Bool=false, kwargs...) not supported for error without mi")) + end + world = ece.world::UInt + mode = Enzyme.API.DEM_ReverseModeCombined + + CT = @static if VERSION >= v"1.11.0-DEV.1552" + EnzymeCacheToken( + typeof(DefaultCompilerTarget()), + false, + GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# + EnzymeCompilerParams, + world, + false, + true, + true + ) + else + Enzyme.Compiler.GLOBAL_REV_CACHE + end + + interp = Enzyme.Compiler.Interpreter.EnzymeInterpreter(CT, nothing, world, mode, true) + + sig = mi.specTypes # XXX: can we just use the method instance? + if interactive + # call Cthulhu without introducing a dependency on Cthulhu + mod = get(Base.loaded_modules, Cthulhu, nothing) + mod===nothing && error("Interactive code reflection requires Cthulhu; please install and load this package first.") + descend_code_typed = getfield(mod, :descend_code_typed) + descend_code_typed(sig; interp, kwargs...) + else + Base.code_typed_by_type(sig; interp, kwargs...) + end +end + + struct EnzymeNoTypeError <: EnzymeError msg::Cstring end @@ -890,7 +939,32 @@ end Base.show_backtrace(io, bt) end end - emit_error(b, nothing, msg2, EnzymeRuntimeActivityError) + + mi = nothing + world = nothing + + if isa(val, LLVM.Instruction) + f = LLVM.parent(LLVM.parent(val))::LLVM.Function + mi, rt = enzyme_custom_extract_mi( + f, + false, + ) #=error=# + world = enzyme_extract_world(f) + elseif isa(val, LLVM.Argument) + f = parent_scope(val)::LLVM.Function + mi, rt = enzyme_custom_extract_mi( + f, + false, + ) #=error=# + world = enzyme_extract_world(f) + end + mode = Enzyme.API.DEM_ReverseModeCombined + + if mi !== nothing + emit_error(b, nothing, (msg2, mi, world), EnzymeRuntimeActivityError{Core.MethodInstance, UInt}) + else + emit_error(b, nothing, msg2, EnzymeRuntimeActivityError{Nothing, Nothing}) + end return C_NULL elseif errtype == API.ET_GetIndexError @assert B != C_NULL diff --git a/src/jlrt.jl b/src/jlrt.jl index 9bf349e20a..fae6b71448 100644 --- a/src/jlrt.jl +++ b/src/jlrt.jl @@ -1074,26 +1074,32 @@ function emit_printf(B::LLVM.IRBuilder, string::String, v::LLVM.Value...) call!(B, LLVM.function_type(exc), exc, args) end -function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.Instruction}), string::Union{String, LLVM.Value}, @nospecialize(errty::Type) = EnzymeRuntimeException, @nospecialize(cond::Union{Nothing, LLVM.Value}) = nothing) +function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.Instruction}), string::Union{String, LLVM.Value, Tuple{String, Core.MethodInstance, UInt}}, @nospecialize(errty::Type) = EnzymeRuntimeException, @nospecialize(cond::Union{Nothing, LLVM.Value}) = nothing) curent_bb = position(B) fn = LLVM.parent(curent_bb) mod = LLVM.parent(fn) - if !isa(string, LLVM.Value) - string = globalstring_ptr!(B, string, "enz_exception") + stringv = string + if stringv isa Tuple + stringv = stringv[1] + end + if !isa(stringv, LLVM.Value) + stringv = globalstring_ptr!(B, stringv, "enz_exception") end ct = if occursin("ptx", LLVM.triple(mod)) || occursin("amdgcn", LLVM.triple(mod)) - + if string isa Tuple + errty = errty.name.wrapper{Nothing, Nothing} + end vt = LLVM.VoidType() ptr = convert(LLVMType, Ptr{Cvoid}) exc, _ = get_function!(mod, "gpu_report_exception", LLVM.FunctionType(vt, [ptr])) - string = ptrtoint!(B, string, ptr) + stringv = ptrtoint!(B, stringv, ptr) - call!(B, LLVM.function_type(exc), exc, [string]) + call!(B, LLVM.function_type(exc), exc, [stringv]) framefn, ft = get_function!( mod, @@ -1129,12 +1135,25 @@ function emit_error(B::LLVM.IRBuilder, @nospecialize(orig::Union{Nothing, LLVM.I call!(B, trap_ft, trap) else if cond !== nothing - emit_conditional_throw!(B, cond, errty, string) + if string isa Tuple + errty = errty.name.wrapper{Nothing, Nothing} + end + emit_conditional_throw!(B, cond, errty, stringv) else err = emit_allocobj!(B, errty) err2 = bitcast!(B, err, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), 10)) err2 = addrspacecast!(B, err2, LLVM.PointerType(LLVM.PointerType(LLVM.Int8Type()), Derived)) - store!(B, string, err2) + store!(B, stringv, err2) + if string isa Tuple + g1 = LLVM.inbounds_gep!(B, LLVM.PointerType(LLVM.Int8Type()), err2, [LLVM.ConstantInt(1)]) + ts = unsafe_to_llvm(B, string[2]) + g1 = LLVM.bitcast!(B, g1, LLVM.PointerType(value_type(ts), Derived)) + store!(B, ts, g1) + g2 = LLVM.inbounds_gep!(B, LLVM.PointerType(LLVM.Int8Type()), err2, [LLVM.ConstantInt(2)]) + ts = LLVM.ConstantInt(string[3]) + g2 = LLVM.bitcast!(B, g2, LLVM.PointerType(value_type(ts), Derived)) + store!(B, ts, g2) + end emit_jl_throw!( B, addrspacecast!(B, err, LLVM.PointerType(LLVM.StructType(LLVMType[]), 12)), diff --git a/test/runtests.jl b/test/runtests.jl index e22e00929d..4103fd1e2d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2033,6 +2033,9 @@ include("applyiter.jl") @test 0.5 ≈ Enzyme.autodiff(Reverse, dyn_mwe, Active, Active(1.0), Const((1, 2)))[1][1] end + +sinadd(x, y) = (sin.(x) .+ (y)) + @testset "broadcast" begin A = rand(10); B = rand(10); R = similar(A) dA = zero(A); dB = zero(B); dR = fill!(similar(R), 1) @@ -2051,6 +2054,10 @@ end dA = zero(A); dB = zero(B); dR = fill!(similar(A), 1) autodiff(Reverse, foo_bc!, Const, Duplicated(A, dR), Duplicated(transpose(A), transpose(dA)), Duplicated(B, dB)) + + # no runtime activity required + res = autodiff(Forward, sinadd, Duplicated([2.7], [4.2]), Const([.31]))[1] + @test [-3.7971029964716574] ≈ res end @@ -3769,22 +3776,13 @@ const objective3 = params -> mixture_loglikelihood3(params, data) -13.935687326484112, -38.00044665702692, 12.87712891527131] - @static if VERSION < v"1.11-" - @test expected ≈ Enzyme.gradient(Reverse, objective1, params0)[1] - else - # TODO broken should not throw - @test_throws Enzyme.Compiler.EnzymeRuntimeActivityError Enzyme.gradient(Reverse, objective1, params0)[1] - @test expected ≈ Enzyme.gradient(set_runtime_activity(Reverse), objective1, params0)[1] - end + @test expected ≈ Enzyme.gradient(Reverse, objective1, params0)[1] + @test expected ≈ Enzyme.gradient(set_runtime_activity(Reverse), objective1, params0)[1] + # objective2 fails from runtime activity requirements # @test expected ≈ Enzyme.gradient(Reverse, objective2, params0)[1] - @static if VERSION < v"1.11-" - @test expected ≈ Enzyme.gradient(Reverse, objective3, params0)[1] - else - # TODO broken should not throw - @test_throws Enzyme.Compiler.EnzymeRuntimeActivityError Enzyme.gradient(Reverse, objective3, params0)[1] - @test expected ≈ Enzyme.gradient(set_runtime_activity(Reverse), objective3, params0)[1] - end + @test expected ≈ Enzyme.gradient(Reverse, objective3, params0)[1] + @test expected ≈ Enzyme.gradient(set_runtime_activity(Reverse), objective3, params0)[1] end struct HarmonicAngle