
mkldnn_impl Class — pytorch Architecture

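Architecture documentation for mkldnn_impl, a function template in RNN.cpp from the PyTorch codebase. It unpacks the caller's hidden state into hx and cx and reads hidden_size from hx. It then forwards everything to mkldnn_rnn and repacks the returned hy and cy into the caller's hidden_type.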


Source Code

aten/src/ATen/native/mkldnn/RNN.cpp lines 557–572

template<typename hidden_type>
std::pair<Tensor, hidden_type> mkldnn_impl(
    const Tensor& input, const hidden_type& hidden,
    TensorList params, bool has_biases, ideep::rnn_kind mode,
    int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
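  // Split the caller's hidden_type into hx and cx.
  auto [hx, cx] = unpack_hidden(hidden);
  // hx is laid out as (num_layers * num_directions, batch, hidden_size).
  int64_t hidden_size = hx.size(2);

  // Weight stride per layer and direction: w_ih, w_hh, plus b_ih, b_hh when has_biases.
  auto mkldnn_output = mkldnn_rnn(
      input, params, has_biases ? 4 : 2,
      hx, cx, static_cast<int>(mode), hidden_size, num_layers, has_biases, batch_first, dropout_p,
      train, bidirectional, /*batch_sizes*/{});

  // Elements 0, 1 and 2 of the result are output, hy and cy; hy/cy are repacked as hidden_type.
  return {std::get<0>(mkldnn_output),
          pack_hidden<hidden_type>(std::get<1>(mkldnn_output), std::get<2>(mkldnn_output))};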
}
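mkldnn_impl is written once for any hidden_type. It depends on helpers that convert between that type and an (hx, cx) pair. Below is a minimal C++17 sketch of that pattern under stated assumptions. It is not PyTorch code. FakeTensor, LSTMHidden, weight_stride_for, fake_rnn and impl_sketch are hypothetical names. The sketch assumes a single-tensor hidden state carries no cell state (cx left undefined) and that input is time-major (batch_first == false). fake_rnn only fabricates an output shape and does not run a recurrent network.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical stand-in for at::Tensor: only a shape; an empty shape means "undefined".
struct FakeTensor {
  std::vector<int64_t> shape;
  bool defined() const { return !shape.empty(); }
  int64_t size(int64_t dim) const { return shape.at(static_cast<std::size_t>(dim)); }
};

// Hypothetical LSTM-style hidden state: (hx, cx).
using LSTMHidden = std::tuple<FakeTensor, FakeTensor>;

// Single-tensor hidden state: no cell state, so cx is left undefined.
std::tuple<FakeTensor, FakeTensor> unpack_hidden(const FakeTensor& hidden) {
  return std::make_tuple(hidden, FakeTensor{});
}

// LSTM-style hidden state is already an (hx, cx) pair.
std::tuple<FakeTensor, FakeTensor> unpack_hidden(const LSTMHidden& hidden) {
  return hidden;
}

// Primary template is only declared; each supported hidden_type gets a specialization.
template <typename hidden_type>
hidden_type pack_hidden(const FakeTensor& hx, const FakeTensor& cx);

template <>
FakeTensor pack_hidden<FakeTensor>(const FakeTensor& hx, const FakeTensor& cx) {
  assert(!cx.defined());  // a single-tensor state must not carry a cell state
  return hx;
}

template <>
LSTMHidden pack_hidden<LSTMHidden>(const FakeTensor& hx, const FakeTensor& cx) {
  return std::make_tuple(hx, cx);
}

// Weight tensors per layer and direction: w_ih, w_hh (+ b_ih, b_hh with biases).
int64_t weight_stride_for(bool has_biases) { return has_biases ? 4 : 2; }

// Hypothetical backend: computes no RNN, just echoes hy/cy and shapes an output of
// (seq_len, batch, hidden_size * num_directions) for a time-major input.
std::tuple<FakeTensor, FakeTensor, FakeTensor> fake_rnn(
    const FakeTensor& input, int64_t weight_stride0, const FakeTensor& hx,
    const FakeTensor& cx, int64_t hidden_size, bool bidirectional) {
  (void)weight_stride0;  // a real backend would use it to index the flat weight list
  FakeTensor output{{input.size(0), input.size(1), hidden_size * (bidirectional ? 2 : 1)}};
  return std::make_tuple(output, hx, cx);
}

// Mirrors mkldnn_impl's structure: unpack, call the backend, repack as hidden_type.
template <typename hidden_type>
std::pair<FakeTensor, hidden_type> impl_sketch(const FakeTensor& input,
                                               const hidden_type& hidden,
                                               bool has_biases, bool bidirectional) {
  auto [hx, cx] = unpack_hidden(hidden);
  int64_t hidden_size = hx.size(2);  // hx: (num_layers * num_directions, batch, hidden_size)
  auto out = fake_rnn(input, weight_stride_for(has_biases), hx, cx, hidden_size, bidirectional);
  return {std::get<0>(out), pack_hidden<hidden_type>(std::get<1>(out), std::get<2>(out))};
}
```

hidden_type is deduced from the argument, so one function body serves both a single tensor and an LSTM-style tuple. Overload resolution on unpack_hidden and the explicit specializations of pack_hidden select the right conversion at compile time. No runtime branching on the hidden-state kind is needed.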
