-
Notifications
You must be signed in to change notification settings - Fork 162
Open
Description
As stated in the title, `accumulate_param_gradients!` does not support a `scale_factor` argument for static generative functions. Calling `accumulate_param_gradients!` with a third argument raises `ERROR: Not implemented`, because the call falls back to the abstract generative function interface (GFI) definition.
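This is caused by (1) the absence of a generated method with the three-argument signature `(trace, retval_grad, scale_factor)` — only the two-argument form `(trace, retval_grad)` is generated: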
Gen.jl/src/static_ir/backprop.jl
Lines 508 to 512 in e5ed96f
push!(generated_functions, quote | |
@generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} | |
$(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad) | |
end | |
end) |
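And (2) the absence of logic in the backward pass for trainable parameter nodes to handle a scale factor — the parameter-gradient update below adds `gradient_var(node)` to the current parameter gradient without scaling it: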
Gen.jl/src/static_ir/backprop.jl
Lines 169 to 185 in e5ed96f
function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::TrainableParameterNode, mode) | |
# handle case when it is the return node | |
if node === ir.return_node && node in fwd_marked | |
@assert node in back_marked | |
push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing"))) | |
push!(stmts, :($(gradient_var(node)) += retval_grad)) | |
end | |
if node in fwd_marked && node in back_marked | |
cur_param_grad = :($(QuoteNode(get_param_grad))(trace.$static_ir_gen_fn_ref, | |
$(QuoteNode(node.name)))) | |
push!(stmts, :($(QuoteNode(set_param_grad!))(trace.$static_ir_gen_fn_ref, | |
$(QuoteNode(node.name)), | |
$cur_param_grad + $(gradient_var(node))))) | |
end | |
end |
Metadata
Metadata
Assignees
Labels
No labels