
How to use dnnl_rnn_flags_diff_weights_overwrite in RNN LSTM primitive creation? #3995

@lacak-sk

Description


The documentation states that the backward primitive ACCUMULATES diff weights (diff_weights_layer, diff_weights_iter) and bias (diff_bias): "The RNN primitive backward pass accumulates gradients to its weight outputs. Hence, these tensors should be properly initialized to zero before their first use, and can be reused across calls to accumulate gradients if need be. This behavior can be altered by the RNN flag diff_weights_overwrite. If this flag is set weight gradients will be initialized by zeros by the RNN primitive."
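For reference, the zero-initialization of the diff tensors before their first use that I have in mind looks roughly like this (a minimal sketch assuming the oneDNN v3.x C API and a CPU engine, where the data handle is a plain host pointer; the helper name is mine):

```c
#include <string.h>
#include "oneapi/dnnl/dnnl.h"

/* zero an already created diff_weights_* / diff_bias memory object
 * before the first backward call (CPU engine only) */
static void zero_diff_memory(dnnl_memory_t mem, const_dnnl_memory_desc_t md)
{
    void *handle = NULL;
    dnnl_memory_get_data_handle(mem, &handle);
    if (handle)
        memset(handle, 0, dnnl_memory_desc_get_size(md));
}
```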

There is a "flags" argument in dnnl_lstm_forward_primitive_desc_create() and also in dnnl_lstm_backward_primitive_desc_create(), but this argument is described as "Unused."

So it is unclear to me how to set this flag (dnnl_rnn_flags_diff_weights_overwrite) through the C API.
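My assumption so far is that "flags" is an unsigned bitmask of dnnl_rnn_flags_t values, so I would OR the flag in and pass the result as that argument when creating the backward primitive descriptor. A minimal compilable sketch of just the flag handling (the include path and the dnnl_rnn_flags_undef constant are what I see in my headers, so treat them as my assumptions):

```c
#include <stdio.h>
#include "oneapi/dnnl/dnnl_types.h"

int main(void)
{
    /* start from "no flags", then OR in the overwrite behavior */
    unsigned flags = dnnl_rnn_flags_undef;
    flags |= dnnl_rnn_flags_diff_weights_overwrite;

    /* this is the value I would expect to pass as the "flags" argument of
     * dnnl_lstm_backward_primitive_desc_create() -- is that correct? */
    printf("flags = 0x%x\n", flags);
    return 0;
}
```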

Do I understand it correctly that I MUST set this flag if I execute multiple forward/backward passes over multiple samples and want to get the diffs for the current pass only, use those diffs to update the weights/biases, then process the next sample, and so on...?

And are the diffs computed by the backward pass SUMMED over time steps (and over the batch)? So if I have 200 tokens in a sequence (200 time steps), must I update the weights by [diffs] * [LR] / 200, where LR is the learning rate?
(In general: [weights] = [weights] - [diff_weights] * [learning_rate] / ([batch_size] * [time_steps]) ?)
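In plain C, the update rule I am asking about would look like this (a sketch of my own understanding only, not anything taken from oneDNN; diff_weights is assumed to already hold the gradient summed over the whole mini-batch and all time steps):

```c
#include <stddef.h>

/* scale the accumulated gradient by lr / (batch_size * time_steps),
 * then take one plain SGD step */
void sgd_update(float *weights, const float *diff_weights, size_t n,
        float learning_rate, int batch_size, int time_steps)
{
    const float scale = learning_rate / (float)(batch_size * time_steps);
    for (size_t i = 0; i < n; ++i)
        weights[i] -= scale * diff_weights[i];
}
```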
