Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1314,13 +1314,21 @@ import .Compiler: remove_innerty, UnknownTapeType
# If the tape is not cached, compile it
if obj === nothing

Compiler.JuliaContext() do ctx
ts_ctx = Compiler.JuliaContext()
ctx = Compiler.context(ts_ctx)
Compiler.activate(ctx)
try
_, meta = GPUCompiler.compile(:llvm, job)
obj = meta.TapeType
tape_cache[key] = obj
obj
finally
Compiler.deactivate(ctx)
Compiler.dispose(ts_ctx)
end
else
obj
end
obj
finally
unlock(tape_cache_lock)
end
Expand Down
25 changes: 25 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,10 @@ function set_module_types!(interp, mod::LLVM.Module, primalf::Union{Nothing, LLV
return custom, state
end

const DumpPreNestedCheck = Ref(false)
const DumpPreNestedOpt = Ref(false)
const DumpPostNestedOpt = Ref(false)

function nested_codegen!(
mode::API.CDerivativeMode,
mod::LLVM.Module,
Expand Down Expand Up @@ -1219,8 +1223,16 @@ function nested_codegen!(
API.AddPreserveNVVMPass!(pm, true) #=Begin=#
LLVM.run!(pm, otherMod)
end

if DumpPreNestedCheck[]
API.EnzymeDumpModuleRef(otherMod.ref)
end

check_ir(interp, job, otherMod)

if DumpPreNestedOpt[]
API.EnzymeDumpModuleRef(otherMod.ref)
end

# Skipped inline of blas

Expand All @@ -1229,6 +1241,11 @@ function nested_codegen!(

# Apply first stage of optimization's so that this module is at the same stage as `mod`
optimize!(otherMod, JIT.get_tm())

if DumpPostNestedOpt[]
API.EnzymeDumpModuleRef(otherMod.ref)
end

# 4) Link the corresponding module
LLVM.link!(mod, otherMod)
# 5) Call the function
Expand Down Expand Up @@ -2228,6 +2245,7 @@ end
include("rules/activityrules.jl")

const DumpPreEnzyme = Ref(false)
const DumpPostEnzyme = Ref(false)
const DumpPostWrap = Ref(false)

function enzyme!(
Expand Down Expand Up @@ -2576,6 +2594,9 @@ function enzyme!(
adjointf = adjointf == nothing ? nothing : functions(mod)[adjointfname]
augmented_primalf =
augmented_primalf == nothing ? nothing : functions(mod)[augmented_primalfname]
if DumpPostEnzyme[]
API.EnzymeDumpModuleRef(mod.ref)
end
return adjointf, augmented_primalf, TapeType
end

Expand Down Expand Up @@ -5755,6 +5776,7 @@ function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module
return CompileResult(adjoint_ptr, primal_ptr, TapeType, edges)
end

const DumpPrePostOpt = Ref(false)
const DumpPostOpt = Ref(false)

# actual compilation
Expand Down Expand Up @@ -5786,6 +5808,9 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, Stri
string(mod)
end
if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI
if DumpPrePostOpt[]
API.EnzymeDumpModuleRef(mod.ref)
end
post_optimze!(mod, JIT.get_tm())
if DumpPostOpt[]
API.EnzymeDumpModuleRef(mod.ref)
Expand Down
156 changes: 131 additions & 25 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -880,56 +880,123 @@ end
@generated function same_sized(x::Tuple)
result = :true
prev = nothing
todo = Tuple{Expr, Type}[]
for i in 1:length(x.parameters)
if x.parameters[i] <: Number
push!(todo, (:(x[$i]), x.parameters[i]))
end
while length(todo) != 0
expr, ty = pop!(todo)
if ty <: Number || ty <: Base.RefValue
continue
end
if ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing}
Comment thread
wsmoses marked this conversation as resolved.
for i in 1:length(ty.parameters[4].parameters)
push!(todo, (:($expr.args[$i]), ty.parameters[4].parameters[i]))
end
continue
end
@assert ty <: AbstractArray
if prev == nothing
prev = quote
sz = size(x[$i])
sz = size($expr)
end
continue
end
if result == :true
result = quote
sz == size(x[$i])
sz == size($expr)
end
else
result = quote
$result && sz == size(x[$i])
$result && sz == size($expr)
end
end
end
if result == :true
return quote
Base.@_inline_meta
true
end
end
return quote
Base.@_inline_meta
$prev
return $result
end
end

@generated function first_array(x::Tuple)
result = :true
prev = nothing
todo = Tuple{Expr, Type}[]
for i in 1:length(x.parameters)
push!(todo, (:(x[$i]), x.parameters[i]))
end
while length(todo) != 0
expr, ty = pop!(todo)
if ty <: Number || ty <: Base.RefValue
continue
end
if ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing}
for i in 1:length(ty.parameters[4].parameters)
push!(todo, (:($expr.args[$i]), ty.parameters[4].parameters[i]))
end
continue
end
@assert ty <: AbstractArray
return quote
Base.@_inline_meta
$expr
end
end
return quote
Base.@_inline_meta
throw(AssertionError("No array"))
end
end


Base.@propagate_inbounds overload_broadcast_getindex(A::Union{Ref,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices
Base.@propagate_inbounds overload_broadcast_getindex(::Ref{Type{T}}, I) where {T} = T
Base.@propagate_inbounds @inline overload_broadcast_getindex(A::Union{Ref,AbstractArray{<:Any,0},Number}, I) = A[] # Scalar-likes can just ignore all indices
Base.@propagate_inbounds @inline overload_broadcast_getindex(::Ref{Type{T}}, I) where {T} = T
# Tuples are statically known to be singleton or vector-like
Base.@propagate_inbounds overload_broadcast_getindex(A::Tuple{Any}, I) = A[1]
Base.@propagate_inbounds overload_broadcast_getindex(A::Tuple, I) = error("unhandled") # A[I[1]]
Base.@propagate_inbounds overload_broadcast_getindex(A, I) = A[I]
Base.@propagate_inbounds @inline overload_broadcast_getindex(A::Tuple{Any}, I) = A[1]
Base.@propagate_inbounds @inline overload_broadcast_getindex(A::Tuple, I) = error("unhandled") # A[I[1]]
Base.@propagate_inbounds @generated function overload_broadcast_getindex(bc::Base.Broadcast.Broadcasted, I)
args = Expr[]
for i in 1:length(bc.parameters[4].parameters)
push!(args, Expr(:call, overload_broadcast_getindex, :(bc.args[$i]), :I))
end
expr = Expr(:call, Base.Broadcast._broadcast_getindex_evalf, :(bc.f), args...)
return quote
Base.@_inline_meta
$expr
end
end

Base.@propagate_inbounds @inline overload_broadcast_getindex(A, I) = @inbounds A[I]

@inline function override_bc_materialize(bc)
struct OverrideBCMaterialize{ElType}
end

@inline function (::OverrideBCMaterialize{ElType})(bc) where ElType
if bc.args isa Tuple{AbstractArray} && bc.f === Base.identity
return copy(bc.args[1])
end
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
dest = similar(bc, ElType)
if all(isa_array_or_number, bc.args) && same_sized(bc.args)
@inbounds @simd for I in 1:length(bc)
val = Base.Broadcast._broadcast_getindex_evalf(bc.f, map(Base.Fix2(overload_broadcast_getindex, I), bc.args)...)
dest = @inline similar(bc, ElType)
if same_sized(bc.args)
# dest = @inline similar(first_array(bc.args), ElType)
@inbounds @simd for I in 1:length(bc)
val = overload_broadcast_getindex(bc, I)
dest[I] = val
end
return dest
else
Base.copyto!(dest, bc)
# The existing code is rather slow for broadcast in practice: https://github.com/EnzymeAD/Enzyme.jl/issues/1434
src = @inline Base.Broadcast.preprocess(nothing, bc)
idx = Base.eachindex(src)
@inline Enzyme.Compiler.Interpreter.lindex_v3(idx, dest, src)
return dest
end
return dest
end

struct MultiOp{Position, NumUsed, F1, F2}
Expand Down Expand Up @@ -958,12 +1025,28 @@ end
end
end

@inline function array_or_number(@nospecialize(Ty))::Bool
return Ty <: AbstractArray || Ty <: Number
@inline function bc_or_array_or_number_ty(@nospecialize(Ty::Type))::Bool
if Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing}
return all(bc_or_array_or_number_ty, Ty.parameters[4].parameters)
else
return Ty <: AbstractArray || Ty <: Number || Ty <: Base.RefValue
end
end

@inline function isa_array_or_number(@nospecialize(x))::Bool
return x isa AbstractArray || x isa Number
@inline function has_array(@nospecialize(Ty::Type))::Bool
if Ty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing}
return any(has_array, Ty.parameters[4].parameters)
else
return Ty <: AbstractArray
end
end

@generated function isa_bc_or_array_or_number(x)::Bool
res = bc_or_array_or_number_ty(x)
return quote
Base.@_inline_meta
$res
end
end

@inline function num_or_eltype(@nospecialize(Ty))::Type
Expand All @@ -974,6 +1057,25 @@ end
end
end


## Computation of inferred result type, for empty and concretely inferred cases only
ty_broadcast_getindex_eltype(interp, bc::Type{<:Base.Broadcast.Broadcasted}) = ty_combine_eltypes(interp, bc.parameters[3], (bc.parameters[4].parameters...,))
ty_broadcast_getindex_eltype(interp, A) = eltype(A) # Tuple, Array, etc.

ty_eltypes(interp, ::Tuple{}) = Tuple{}
ty_eltypes(interp, t::Tuple{Any}) = Iterators.TupleOrBottom(ty_broadcast_getindex_eltype(interp, t[1]))
ty_eltypes(interp, t::Tuple{Any,Any}) = Iterators.TupleOrBottom(ty_broadcast_getindex_eltype(interp, t[1]), ty_broadcast_getindex_eltype(interp, t[2]))
ty_eltypes(interp, t::Tuple) = (TT = ty_eltypes(interp, Base.tail(t)); TT === Union{} ? Union{} : Iterators.TupleOrBottom(ty_broadcast_getindex_eltype(interp, t[1]), TT.parameters...))
# eltypes(t::Tuple) = Iterators.TupleOrBottom(ntuple(i -> _broadcast_getindex_eltype(t[i]), Val(length(t)))...)

# Inferred eltype of result of broadcast(f, args...)
function ty_combine_eltypes(interp, f, args::Tuple)
argT = ty_eltypes(interp, args)
argT === Union{} && return Union{}
preprom = Core.Compiler._return_type(interp, Tuple{f, argT.parameters...})
return Base.promote_typejoin_union(preprom)
end

function abstract_call_known(
interp::EnzymeInterpreter{Handler},
@nospecialize(f),
Expand Down Expand Up @@ -1012,20 +1114,24 @@ function abstract_call_known(
if interp.broadcast_rewrite
if f === Base.materialize && length(argtypes) == 2
bcty = widenconst(argtypes[2])
if Base.isconcretetype(bcty) && bcty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} && all(array_or_number, bcty.parameters[4].parameters) && any(Base.Fix2(Base.:<:, AbstractArray), bcty.parameters[4].parameters)
if Base.isconcretetype(bcty) && bcty <: Base.Broadcast.Broadcasted{<:Base.Broadcast.DefaultArrayStyle, Nothing} && bc_or_array_or_number_ty(bcty) && has_array(bcty)
ElType = ty_broadcast_getindex_eltype(interp, bcty)
if ElType !== Union{} && Base.isconcretetype(ElType)
fn2 = Enzyme.Compiler.Interpreter.OverrideBCMaterialize{ElType}()
arginfo2 = ArgInfo(
fargs isa Nothing ? nothing :
[:(Enzyme.Compiler.Interpreter.override_bc_materialize), fargs[2:end]...],
[Core.Const(Enzyme.Compiler.Interpreter.override_bc_materialize), argtypes[2:end]...],
[:(fn2), fargs[2:end]...],
[Core.Const(fn2), argtypes[2:end]...],
)
return Base.@invoke abstract_call_known(
interp::AbstractInterpreter,
Enzyme.Compiler.Interpreter.override_bc_materialize::Any,
fn2::Any,
arginfo2::ArgInfo,
si::StmtInfo,
sv::AbsIntState,
max_methods::Int,
)
end
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
if isa(inst, LLVM.CallInst)
push!(calls, inst)
# remove illegal invariant.load and jtbaa_const invariants
elseif isa(inst, LLVM.LoadInst)

elseif isa(inst, LLVM.LoadInst)
fn_got, _ = get_base_and_offset(operands(inst)[1]; offsetAllowed=false, inttoptr=false)
fname = String(name(fn_got))
match_ = match(r"^jlplt_(.*)_\d+_got$", fname)
Expand Down Expand Up @@ -468,6 +467,7 @@ function check_ir!(interp, @nospecialize(job::CompilerJob), errors::Vector{IRErr
inst = pop!(calls)
check_ir!(interp, job, errors, imported, inst, calls, mod)
end

return errors
end

Expand Down
Loading
Loading