Home / Class / host_softmax_backward Class — PyTorch Architecture

host_softmax_backward Class — PyTorch Architecture

Architecture documentation for the host_softmax_backward function template (a CPU kernel, not a class) in aten/src/ATen/native/SoftMax.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/SoftMax.cpp lines 245–300

template <typename scalar_t>
void host_softmax_backward(
    const Tensor& gI,
    const Tensor& grad,
    const Tensor& output,
    int64_t dim,
    bool* mask = nullptr) {

  int64_t outer_size = 1;
  int64_t dim_size = grad.size(dim);
  int64_t inner_size = 1;
  for (const auto i : c10::irange(dim)) {
    outer_size *= grad.size(i);
  }
  for (int64_t i = dim + 1; i < grad.dim(); ++i) {
    inner_size *= grad.size(i);
  }
  int64_t dim_stride = inner_size;
  int64_t outer_stride = dim_size * dim_stride;
  scalar_t* gradInput_data_base = gI.data_ptr<scalar_t>();
  scalar_t* output_data_base = output.data_ptr<scalar_t>();
  scalar_t* gradOutput_data_base = grad.data_ptr<scalar_t>();
  bool* mask_data_base = mask;
  int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, static_cast<int64_t>(1));
  parallel_for(
      0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
        for (const auto i : c10::irange(begin, end)) {
          int64_t outer_idx = i / inner_size;
          int64_t inner_idx = i % inner_size;
          scalar_t* gradInput_data =
              gradInput_data_base + outer_idx * outer_stride + inner_idx;
          scalar_t* output_data =
              output_data_base + outer_idx * outer_stride + inner_idx;
          const scalar_t* gradOutput_data =
              gradOutput_data_base + outer_idx * outer_stride + inner_idx;
          bool* mask_data = mask_data_base + outer_idx * outer_stride + inner_idx;

          acc_type<scalar_t, false> sum = 0;
          for (const auto d : c10::irange(dim_size)) {
            if (!mask_data[d * dim_stride]) {
              sum +=
                  gradOutput_data[d * dim_stride] * output_data[d * dim_stride];
            }
          }

          for (const auto d : c10::irange(dim_size)) {
            if (mask_data[d * dim_stride]) {
              gradInput_data[d * dim_stride] = 0;
            } else {
              gradInput_data[d * dim_stride] = output_data[d * dim_stride] *
                  (gradOutput_data[d * dim_stride] - sum);
            }
          }
        }
      });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free