Status: Closed
Labels: enhancement (New feature or request)
Description
Motivation and description
Using `trainable`, we can walk a model and apply a function only to its trainable parameters. But the gradient from Zygote is a plain named tuple that carries none of this information.
Normally, for optimizers this is fine: the update function is applied at every leaf while walking the model and gradient together, so a single pass over the model suffices. But it is fairly common to first walk the entire tree of gradients to compute something (e.g. a global norm). In that case, we need a pass over the gradient outside of the update context, where the trainable filter is no longer available.
Possible Implementation
We could add a function `maptrainable(f, model, [gradient])` (or a better name) that maps `f` over the trainable parameters of `model`.
- If another tree such as `gradient` is passed, then `f` is applied to the leaves of `gradient` (i.e. approximately `fmap(TrainableWalk(f), gradient, model)`, using the last argument to filter the walk).
- If no other tree is passed, we just apply `f` to `model` (this is a simple walk, but it may be good for consistency).
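A minimal, dependency-free sketch of the proposed behavior, assuming models are nested NamedTuples or simple structs, numeric arrays are leaves, and gradients are Zygote-style NamedTuples using `nothing` for absent gradients. `trainable_fields`, `Layer`, and `trainable_norm` are hypothetical names: `trainable_fields` stands in for `Optimisers.trainable` (which returns children rather than field names), and none of this is the actual Optimisers.jl or Functors.jl API.

```julia
# Sketch of the proposed `maptrainable`, using only Base Julia.
# Assumptions (hypothetical, not the Optimisers.jl/Functors.jl API):
#   * numeric arrays are leaves;
#   * `trainable_fields(x)` stands in for `Optimisers.trainable` and lists
#     the names of the trainable children of `x`;
#   * gradients are Zygote-style NamedTuples, with `nothing` for no gradient.

isleaf(x) = x isa AbstractArray{<:Number}

# By default every field of a NamedTuple is trainable; other values have none.
trainable_fields(x::NamedTuple) = keys(x)
trainable_fields(x) = ()

# Hypothetical layer whose `state` (e.g. running statistics) is not trained.
struct Layer{W,S}
    weight::W
    state::S
end
trainable_fields(::Layer) = (:weight,)

"""
    maptrainable(f, model, [grad])

Apply `f` to the leaves of `grad` whose position in `model` is trainable.
Returns a NamedTuple tree of results; non-trainable fields are dropped and
`nothing` gradients are passed through unchanged.
"""
function maptrainable(f, model, grad)
    grad === nothing && return nothing
    isleaf(model) && return f(grad)
    ks = Tuple(trainable_fields(model))
    # Recurse only into trainable children, pairing model and gradient by name.
    vals = map(k -> maptrainable(f, getfield(model, k), getfield(grad, k)), ks)
    return NamedTuple{ks}(vals)
end

# With no second tree, walk the model itself.
maptrainable(f, model) = maptrainable(f, model, model)

# Example use: global norm over the trainable part of the gradient only.
function trainable_norm(model, grad)
    total = Ref(0.0)
    maptrainable(g -> (total[] += sum(abs2, g); nothing), model, grad)
    return sqrt(total[])
end

model = (layer1 = Layer([1.0, 2.0], [10.0]), bias = [0.5])
grad  = (layer1 = (weight = [3.0, 4.0], state = [100.0]), bias = nothing)

println(trainable_norm(model, grad))   # 5.0: the gradient of `state` is skipped
```

Dropping non-trainable fields from the result keeps reductions like the global norm restricted to trainable parameters; a real implementation might instead mirror the full input structure and leave non-trainable entries untouched.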
ericphanson