ReversedPackedLayer Class — PyTorch Architecture
Architecture documentation for the `ReversedPackedLayer` class in `RNN.cpp` from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/RNN.cpp lines 994–1046
/// Runs a recurrent Cell over a PackedSequence in reverse time order.
///
/// The packed data is walked from its last time step back to its first. The
/// hidden state starts with only the rows active at the last step
/// (`batch_sizes[num_steps - 1]`). Whenever the batch size grows while moving
/// backwards, extra rows are sliced from `input_hidden` and concatenated on.
///
/// NOTE(review): assumes `input.batch_sizes` is non-empty and non-increasing
/// (the usual PackedSequence layout). An empty `batch_sizes` would make the
/// `batch_sizes[num_steps - 1]` read below out of bounds. TODO confirm that
/// callers guarantee this.
///
/// The layer stores a reference to its cell and does not own it, so the cell
/// must outlive the layer.
template<typename hidden_type, typename cell_params>
struct ReversedPackedLayer : Layer<PackedSequence, hidden_type, cell_params> {
  using output_type =
      typename Layer<PackedSequence, hidden_type, cell_params>::output_type;

  /// @param cell Cell applied at every time step; held by reference.
  ReversedPackedLayer(Cell<hidden_type, cell_params>& cell)
    : cell_(cell) {}

  /// @param input        Packed input sequence (`data` plus per-step `batch_sizes`).
  /// @param input_hidden Initial hidden state; rows are sliced out of it as the
  ///                     active batch grows.
  /// @param params       Cell parameters forwarded to `cell_` on every step.
  /// @return A PackedSequence of per-step outputs, concatenated in the same
  ///         step order as `input.data` and sharing `input.batch_sizes`,
  ///         together with the hidden state left after processing step 0.
  output_type operator()(
      const PackedSequence& input,
      const hidden_type& input_hidden,
      const cell_params& params) const override {
    const int64_t num_steps = input.batch_sizes.size(0);
    const int64_t* batch_sizes = input.batch_sizes.const_data_ptr<int64_t>();

    // Exactly one output is produced per time step, so size the vector once.
    std::vector<at::Tensor> step_outputs;
    step_outputs.reserve(static_cast<size_t>(num_steps));

    // `input_offset` walks `input.data` from its end towards its start.
    int64_t input_offset = input.data.size(0);
    int64_t last_batch_size = batch_sizes[num_steps - 1];

    // On CPU, apply the input-to-hidden projection to the whole packed input
    // once, and tell the cell that its per-step input is already projected.
    const Tensor* input_ptr = &input.data;
    bool pre_compute_input = false;
    Tensor input_w;
    if (input.data.device().is_cpu()) {
      input_w = params.linear_ih(input.data);
      input_ptr = &input_w;
      pre_compute_input = true;
    }

    // Start from the batch active at the last time step and progressively
    // widen the hidden state as earlier (larger) steps are reached.
    auto hidden = hidden_slice(input_hidden, 0, last_batch_size);
    for (int64_t i = num_steps - 1; i >= 0; --i) {
      const int64_t batch_size = batch_sizes[i];
      if (batch_size > last_batch_size) {
        // Bring in the initial hidden rows for sequences that become active.
        hidden = hidden_concat(ArrayRef<hidden_type>{
            hidden, hidden_slice(input_hidden, last_batch_size, batch_size)});
      }
      input_offset -= batch_size;
      auto step_input = input_ptr->narrow(0, input_offset, batch_size);
      last_batch_size = batch_size;
      hidden = cell_(step_input, hidden, params, pre_compute_input);
      step_outputs.emplace_back(hidden_as_output(hidden));
    }
    // Outputs were produced last-step-first; restore the packed (forward) order.
    std::reverse(step_outputs.begin(), step_outputs.end());
    return {PackedSequence{at::cat(step_outputs, 0), input.batch_sizes},
            hidden};
  }

  Cell<hidden_type, cell_params>& cell_;
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free