accum_param_gradients! does not support scale_factor for static functions #387

@ztangent

Description

As stated in the title, accumulate_param_gradients! does not support the scale_factor argument for static generative functions. Calling it with a third argument throws ERROR: Not implemented, because dispatch falls through to the abstract GFI definition.
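The dispatch behavior described above can be reproduced in isolation. The sketch below uses hypothetical stand-ins (AbstractTrace, StaticTrace, accum!), not Gen's actual types or functions. It assumes the abstract fallback looks roughly like a three-argument method that errors, plus a two-argument convenience method, and it shows that a specialization defined only for two arguments cannot catch a three-argument call.

```julia
# Hypothetical stand-ins for Gen's trace types; these are NOT Gen's API.
abstract type AbstractTrace end
struct StaticTrace <: AbstractTrace end

# GFI-style fallback (assumed shape): errors for any trace type without a specialization.
accum!(trace::AbstractTrace, retgrad, scale_factor) = error("Not implemented")

# Convenience method that fills in a default scale factor of 1.0.
accum!(trace::AbstractTrace, retgrad) = accum!(trace, retgrad, 1.0)

# Specialization defined only for the two-argument signature,
# analogous to the generated method for static traces.
accum!(trace::StaticTrace, retgrad) = :static_two_arg

# Two arguments: the StaticTrace method is more specific, so it wins.
two_arg = accum!(StaticTrace(), nothing)

# Three arguments: only the abstract fallback matches, so it errors.
three_arg = try
    accum!(StaticTrace(), nothing, 0.5)
catch err
    err.msg
end

println(two_arg)    # prints static_two_arg
println(three_arg)  # prints Not implemented
```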

This is caused by two gaps. (1) The static IR only generates a two-argument method (trace, retval_grad); there is no generated method whose signature also accepts scale_factor:

push!(generated_functions, quote
    @generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))}
        $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad)
    end
end)
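For point (1), one possible shape of the missing definition is a generated method whose signature includes scale_factor, so that the generated body can refer to it. The sketch below is a self-contained toy, not a patch to Gen: ScaledStaticTrace, accum!, and the code generator codegen_accum are hypothetical names, and the generated body simply scales a stored gradient to show that the third argument is in scope.

```julia
# Hypothetical stand-ins; none of these names are Gen's API.
abstract type AbstractTrace end

struct ScaledStaticTrace <: AbstractTrace
    grad::Float64   # stand-in for a gradient computed by the backward pass
end

# GFI-style fallback that errors for any trace type lacking a specialization.
accum!(trace::AbstractTrace, retgrad, scale_factor) = error("Not implemented")

# Hypothetical code generator: receives argument *types*, returns the method body.
function codegen_accum(trace_type, retgrad_type, scale_type)
    # The returned expression runs at call time, where `trace` and
    # `scale_factor` refer to the runtime argument values.
    :(scale_factor * trace.grad)
end

# Generated method whose signature includes the scale factor, so a
# three-argument call no longer falls through to the fallback.
@generated function accum!(trace::T, retval_grad, scale_factor) where {T<:ScaledStaticTrace}
    codegen_accum(trace, retval_grad, scale_factor)
end

result = accum!(ScaledStaticTrace(3.0), nothing, 0.5)
println(result)  # prints 1.5
```

Defining the code generator before the @generated method matters: the generator runs at compile time and can only call functions that already exist when the method is defined.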

(2) The backward pass for trainable parameter nodes has no logic to apply a scale factor; it adds the unscaled gradient directly to the accumulated parameter gradient:

function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::TrainableParameterNode, mode)
    # handle case when it is the return node
    if node === ir.return_node && node in fwd_marked
        @assert node in back_marked
        push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing")))
        push!(stmts, :($(gradient_var(node)) += retval_grad))
    end
    if node in fwd_marked && node in back_marked
        cur_param_grad = :($(QuoteNode(get_param_grad))(trace.$static_ir_gen_fn_ref,
                                                        $(QuoteNode(node.name))))
        push!(stmts, :($(QuoteNode(set_param_grad!))(trace.$static_ir_gen_fn_ref,
                                                      $(QuoteNode(node.name)),
                                                      $cur_param_grad + $(gradient_var(node)))))
    end
end
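For point (2), a plausible fix is to scale the node's gradient before it is added, i.e. emit cur_param_grad + scale_factor * gradient instead of cur_param_grad + gradient, assuming the generated method from point (1) binds scale_factor. The runtime effect of that statement is sketched below with a Dict standing in for the parameter-gradient store; get_grad, set_grad!, and accumulate_scaled! are hypothetical helpers, not Gen's get_param_grad / set_param_grad!.

```julia
# Hypothetical parameter-gradient store: maps parameter names to accumulated gradients.
param_grads = Dict{Symbol,Float64}(:theta => 0.0)

# Hypothetical accessors, standing in for get_param_grad / set_param_grad!.
get_grad(store, name) = store[name]
set_grad!(store, name, value) = (store[name] = value; nothing)

# Accumulate one backward-pass contribution, scaling it before it is added.
function accumulate_scaled!(store, name, grad, scale_factor)
    cur = get_grad(store, name)
    set_grad!(store, name, cur + scale_factor * grad)
    return store
end

accumulate_scaled!(param_grads, :theta, 2.0, 0.5)   # 0.0 + 0.5 * 2.0 = 1.0
accumulate_scaled!(param_grads, :theta, 4.0, 0.25)  # 1.0 + 0.25 * 4.0 = 2.0
println(param_grads[:theta])  # prints 2.0
```

Scaling at the accumulation step leaves the per-node gradient variable untouched, so other uses of that gradient within the backward pass are unaffected.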

Metadata

Assignees

No one assigned

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
