example: add backward propagation to vanilla rnn example #3329
base: main
Conversation
read_from_dnnl_memory(dst_layer_data.data(), dst_layer_mem);
//
// User updates weights and bias using diffs
//
It is recommended to read the data from the dst memory at the end, following the structure used in the other examples. Why was it removed from the original example?
examples/primitives/vanilla_rnn.cpp
Outdated
// Create dnnl::stream.
dnnl::stream engine_stream(engine);
dnnl::engine eng = dnnl::engine(engine_kind, 0);
dnnl::stream s = dnnl::stream(eng);
Minor suggestion: for consistency with the other examples, consider using engine_stream as the stream name, and engine for the dnnl::engine.
.execute(engine_stream, user_weights_layer_mem,
        weights_layer_mem);
}
// We also create workspace memory based on the information from
Why was the reordering part removed from the original example?
examples/primitives/vanilla_rnn.cpp
Outdated
auto src_iter_mem = memory(vanilla_rnn_pd.src_iter_desc(), engine);
auto dst_iter_mem = memory(vanilla_rnn_pd.dst_iter_desc(), engine);
auto workspace_mem = memory(vanilla_rnn_pd.workspace_desc(), engine);
auto workspace_memory = create_ws(vanilla_rnn_pd);
Why were src_iter_mem and dst_iter_mem removed from the original example? I believe they are also necessary for backward propagation.
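For background on why the iteration tensors matter: in a vanilla RNN, the hidden state produced at one time step (the destination iteration state) becomes the input state (the source iteration state) of the next step, so dropping them loses the recurrent state. Below is a minimal plain-C++ sketch of a tanh RNN step that illustrates this, deliberately independent of the oneDNN API; all names here are illustrative, not oneDNN identifiers.

```cpp
#include <cmath>
#include <vector>

// One step of a vanilla (tanh) RNN cell, with scalar weights for clarity:
//   h_t = tanh(w_x * x_t + w_h * h_prev + b)
// h_prev plays the role of the "source iteration" state; the returned h_t
// is the "destination iteration" state that feeds the next time step.
double rnn_step(double x_t, double h_prev, double w_x, double w_h, double b) {
    return std::tanh(w_x * x_t + w_h * h_prev + b);
}

// Run a short sequence, threading the hidden state through each step.
std::vector<double> rnn_forward(const std::vector<double> &xs, double h0,
        double w_x, double w_h, double b) {
    std::vector<double> hs;
    double h = h0;
    for (double x : xs) {
        h = rnn_step(x, h, w_x, w_h, b); // h carries the iteration state
        hs.push_back(h);
    }
    return hs;
}
```

The point of the sketch: each step both consumes and produces an iteration state, which is why the example's src_iter and dst_iter memories are not redundant.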
examples/primitives/vanilla_rnn.cpp
Outdated
auto vanilla_rnn_pd = vanilla_rnn_forward::primitive_desc(engine,
        prop_kind::forward_training, dnnl::algorithm::eltwise_tanh,
        rnn_direction::unidirectional_left2right, src_layer_md, src_iter_md,
auto vanilla_rnn_pd = vanilla_rnn_forward::primitive_desc(
Suggestion: consider renaming vanilla_rnn_pd and vanilla_rnn_bwd_desc to follow a consistent naming pattern. The same applies to vanilla_rnn_prim and vanilla_rnn_backward_prim, to vanilla_rnn_args and vanilla_rnn_bwd_args, and to others.
examples/primitives/vanilla_rnn.cpp
Outdated
/// > Annotated version: @ref vanilla_rnn_example_cpp
///
/// @page vanilla_rnn_example_cpp_short
/// @copybrief vanilla_rnn
Why this change? Please check whether this is consistent with the other examples.
examples/primitives/vanilla_rnn.cpp
Outdated
///
/// @page vanilla_rnn_fwd_bwd_cpp RNN f32 training example
The name vanilla_rnn_fwd_bwd_cpp needs to be updated here.
/// This C++ API example demonstrates how to create and execute a
/// [Vanilla RNN](@ref dev_guide_rnn) primitive in forward training propagation
/// [Vanilla RNN](@ref dev_guide_rnn) primitive in forward and backward training propagation
/// mode.
///
Why were the "Key optimizations included in this example" comments removed from the original example? Please ensure the structure aligns with that of the other existing examples.
Hi @shu1chen, please check the updated version.
To pass the clang-format check, please format the example file before pushing your commit. It looks good to me otherwise.
{DNNL_ARG_WEIGHTS_LAYER, weights_layer_bwd_mem});
vanilla_rnn_bwd_args.insert({DNNL_ARG_WEIGHTS_ITER, weights_iter_mem});
vanilla_rnn_bwd_args.insert({DNNL_ARG_BIAS, bias_bwd_mem});
vanilla_rnn_bwd_args.insert({DNNL_ARG_DST_LAYER, dst_layer_mem});
Does the backward part also need src_iter_mem and dst_iter_mem?
Description
This PR adds the backward propagation primitive to the vanilla_rnn.cpp example.
Fixes # (github issue)
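For readers less familiar with what the backward pass computes: a vanilla RNN step is h = tanh(w_x*x + w_h*h_prev + b), and backward propagation produces the diffs referenced in the review (weight diffs, bias diff, and the gradient flowing back to the previous iteration state). The sketch below shows the underlying math in plain, self-contained C++; it is illustrative only and does not use the oneDNN API.

```cpp
#include <cmath>

// Gradients of one tanh RNN step h = tanh(w_x*x + w_h*h_prev + b),
// given the forward output h and the upstream gradient dh = dL/dh.
struct RnnStepGrads {
    double dw_x;    // diff for the input weight
    double dw_h;    // diff for the recurrent (iteration) weight
    double db;      // diff for the bias
    double dh_prev; // gradient passed back to the previous time step
};

RnnStepGrads rnn_step_backward(
        double x, double h_prev, double w_h, double h, double dh) {
    // d(tanh(pre))/d(pre) = 1 - tanh(pre)^2 = 1 - h*h
    const double dpre = dh * (1.0 - h * h);
    return {dpre * x,      // dL/dw_x
            dpre * h_prev, // dL/dw_h
            dpre,          // dL/db
            dpre * w_h};   // dL/dh_prev, chains into the previous step
}
```

The diffs returned here correspond to the "User updates weights and bias using diffs" step in the example: after the backward primitive runs, the user applies dw_x, dw_h, and db to the weights and bias.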
Checklist
General
Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
Only tested the new example with oneDNN v3.8.