example: add backward propagation to vanilla rnn example #3329


Open · raistefintel wants to merge 2 commits into main

Conversation

@raistefintel (Contributor) commented May 26, 2025

Description

Add the backward propagation primitive to the vanilla_rnn.cpp example.

Fixes # (github issue)

Checklist

General

Note: the new example has only been tested with oneDNN v3.8.

  • Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
  • Have you formatted the code using clang-format?

@raistefintel raistefintel requested a review from a team as a code owner May 26, 2025 10:58
read_from_dnnl_memory(dst_layer_data.data(), dst_layer_mem);
//
// User updates weights and bias using diffs
//
@shu1chen (Contributor) commented May 27, 2025:

It is recommended to read the data from the dst memory at the end, following the structure used in the other examples. Why remove it from the original example?
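
For reference, a minimal sketch of that closing structure, following the other examples (read_from_dnnl_memory is the helper from example_utils.hpp; dst_layer_data and dst_layer_mem are assumed to be this example's output buffer and memory object):

    // Wait for all queued primitives to finish, then read the result
    // back from the dst memory into the user-side buffer.
    engine_stream.wait();
    read_from_dnnl_memory(dst_layer_data.data(), dst_layer_mem);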

// Create dnnl::stream.
dnnl::stream engine_stream(engine);
dnnl::engine eng = dnnl::engine(engine_kind, 0);
dnnl::stream s = dnnl::stream(eng);
@shu1chen (Contributor) commented:

Minor suggestion: for consistency with the other examples, consider using engine_stream as the stream name and engine for the dnnl::engine object.
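
In other words, something along these lines (a sketch showing only the suggested names):

    // Names consistent with the other examples.
    dnnl::engine engine(engine_kind, 0);
    dnnl::stream engine_stream(engine);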

.execute(engine_stream, user_weights_layer_mem,
weights_layer_mem);
}
// We also create workspace memory based on the information from
@shu1chen (Contributor) commented:

Why was the reordering part removed from the original example?
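
For reference, the reordering in the original example looks roughly like this (a sketch; user_weights_layer_mem holds the user-provided weights and weights_layer_mem the primitive-preferred copy):

    // Reorder user weights into the format the primitive prefers, but
    // only when the two memory descriptors actually differ.
    auto weights_layer_mem = user_weights_layer_mem;
    if (vanilla_rnn_pd.weights_layer_desc() != user_weights_layer_mem.get_desc()) {
        weights_layer_mem = memory(vanilla_rnn_pd.weights_layer_desc(), engine);
        reorder(user_weights_layer_mem, weights_layer_mem)
                .execute(engine_stream, user_weights_layer_mem, weights_layer_mem);
    }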

auto src_iter_mem = memory(vanilla_rnn_pd.src_iter_desc(), engine);
auto dst_iter_mem = memory(vanilla_rnn_pd.dst_iter_desc(), engine);
auto workspace_mem = memory(vanilla_rnn_pd.workspace_desc(), engine);
auto workspace_memory = create_ws(vanilla_rnn_pd);
@shu1chen (Contributor) commented:

Why were src_iter_mem and dst_iter_mem removed from the original example? I believe they are also necessary for backward propagation.
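
For comparison, a sketch of how the iteration memories are usually wired into the forward execution arguments in the RNN examples (names assumed to match this example; weights and bias entries omitted for brevity):

    // The hidden-state memories travel through the args map just like
    // the layer memories; the backward pass will need them again.
    std::unordered_map<int, memory> vanilla_rnn_args;
    vanilla_rnn_args.insert({DNNL_ARG_SRC_LAYER, src_layer_mem});
    vanilla_rnn_args.insert({DNNL_ARG_SRC_ITER, src_iter_mem});
    vanilla_rnn_args.insert({DNNL_ARG_DST_LAYER, dst_layer_mem});
    vanilla_rnn_args.insert({DNNL_ARG_DST_ITER, dst_iter_mem});
    vanilla_rnn_args.insert({DNNL_ARG_WORKSPACE, workspace_mem});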

auto vanilla_rnn_pd = vanilla_rnn_forward::primitive_desc(engine,
prop_kind::forward_training, dnnl::algorithm::eltwise_tanh,
rnn_direction::unidirectional_left2right, src_layer_md, src_iter_md,
auto vanilla_rnn_pd = vanilla_rnn_forward::primitive_desc(
@shu1chen (Contributor) commented:

Suggestion: consider renaming vanilla_rnn_pd and vanilla_rnn_bwd_desc to follow a consistent naming pattern. The same applies to vanilla_rnn_prim and vanilla_rnn_backward_prim, to vanilla_rnn_args and vanilla_rnn_bwd_args, and to others.
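
For instance, one possible consistent scheme (the names here are only a suggestion, not the PR's actual identifiers):

    // Forward/backward pairs share a common stem: *_fwd_* / *_bwd_*.
    auto vanilla_rnn_fwd_prim = vanilla_rnn_forward(vanilla_rnn_fwd_pd);
    auto vanilla_rnn_bwd_prim = vanilla_rnn_backward(vanilla_rnn_bwd_pd);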

/// > Annotated version: @ref vanilla_rnn_example_cpp
///
/// @page vanilla_rnn_example_cpp_short
/// @copybrief vanilla_rnn
@shu1chen (Contributor) commented:

Why this change? Please check whether this is consistent with the other examples.

///
/// @page vanilla_rnn_fwd_bwd_cpp RNN f32 training example
@shu1chen (Contributor) commented:

The name vanilla_rnn_fwd_bwd_cpp needs to be updated here.

/// This C++ API example demonstrates how to create and execute a
/// [Vanilla RNN](@ref dev_guide_rnn) primitive in forward training propagation
/// [Vanilla RNN](@ref dev_guide_rnn) primitive in forward and backward training propagation
/// mode.
///
@shu1chen (Contributor) commented:

Why were the "Key optimizations included in this example" comments removed from the original example? Please ensure the structure aligns with that of the other existing examples.
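
For reference, the other examples open with a short doxygen block along these lines (exact wording assumed here; adapt it to what vanilla_rnn.cpp originally had):

    /// Key optimizations included in this example:
    /// - Creation of memory objects in formats chosen by the primitive.
    /// - Reuse of the forward-pass workspace in the backward pass.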

@raistefintel raistefintel marked this pull request as draft May 27, 2025 15:43
@raistefintel (Contributor, Author) commented:

Hi @shu1chen, please check the updated version.

@raistefintel raistefintel marked this pull request as ready for review May 28, 2025 14:54
@raistefintel raistefintel requested a review from shu1chen May 28, 2025 14:54
@shu1chen (Contributor) left a comment:

To pass the clang-format check, please format the example file before pushing your commit. It looks good to me otherwise.

{DNNL_ARG_WEIGHTS_LAYER, weights_layer_bwd_mem});
vanilla_rnn_bwd_args.insert({DNNL_ARG_WEIGHTS_ITER, weights_iter_mem});
vanilla_rnn_bwd_args.insert({DNNL_ARG_BIAS, bias_bwd_mem});
vanilla_rnn_bwd_args.insert({DNNL_ARG_DST_LAYER, dst_layer_mem});
@shu1chen (Contributor) commented:

Does the backward part also need src_iter_mem and dst_iter_mem?
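
If they are needed, a sketch of the additional backward arguments (memory objects assumed to be carried over from the forward pass; the DIFF counterparts are the backward primitive's gradient tensors):

    // Backward consumes the forward hidden-state tensors plus the
    // workspace, and reads/writes the corresponding DIFF tensors.
    vanilla_rnn_bwd_args.insert({DNNL_ARG_SRC_ITER, src_iter_mem});
    vanilla_rnn_bwd_args.insert({DNNL_ARG_DST_ITER, dst_iter_mem});
    vanilla_rnn_bwd_args.insert({DNNL_ARG_DIFF_SRC_ITER, diff_src_iter_mem});
    vanilla_rnn_bwd_args.insert({DNNL_ARG_DIFF_DST_ITER, diff_dst_iter_mem});
    vanilla_rnn_bwd_args.insert({DNNL_ARG_WORKSPACE, workspace_mem});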
