spmm_reduce_backward_input_arg_kernel_impl Class — pytorch Architecture

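Architecture documentation for spmm_reduce_backward_input_arg_kernel_impl, a function template (not a class) in SpmmReduceKernel.cpp from the pytorch codebase. It computes the gradient with respect to the values of the sparse input for the arg-based reductions (max/min), using the saved argument indices in arg_out_.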

Source Code

aten/src/ATen/native/cpu/SpmmReduceKernel.cpp lines 291–345

template <typename scalar_t, typename index_t>
void spmm_reduce_backward_input_arg_kernel_impl(
    const Tensor& grad_self,
    const Tensor& grad_out_,
    const Tensor& col_indices,
    const Tensor& other_,
    const Tensor& arg_out_) {

  int64_t nnz = grad_self._nnz();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto other = other_.contiguous();
  auto arg_out = arg_out_.contiguous();

  auto grad_values = grad_self.values();
  auto grad_values_data = grad_values.accessor<scalar_t, 1>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  const scalar_t* other_data = other.const_data_ptr<scalar_t>();
  index_t* arg_out_data = arg_out.data_ptr<index_t>();

  int64_t M = grad_out.size(0);
  int64_t K = grad_out.size(1);
  auto grad = at::empty({M, K}, grad_out.options());
  scalar_t* grad_data = grad.mutable_data_ptr<scalar_t>();

  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      const scalar_t* grad_out_ptr = grad_out_data + m * K;
      scalar_t* grad_ptr = grad_data + m * K;
      index_t* arg_out_ptr = arg_out_data + m * K;

      for (const auto k : c10::irange(K)) {
        if (arg_out_ptr[k] == index_t(nnz)) {
          grad_ptr[k] = scalar_t(0);
        } else {
          // collect weight at max/min indices
          index_t col = col_data[arg_out_data[m * K + k]];
          grad_ptr[k] = other_data[col * K + k] * grad_out_ptr[k];
        }
      }
    }
  });

  // scatter_add, consider to parallel this with atomic
  for (const auto i : c10::irange(M * K)) {
    index_t ind = arg_out_data[i];
    if (ind != index_t(nnz)) {
      grad_values_data[ind] += grad_data[i];
    }
  }
}
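
The listing above works in two phases: a parallel pass fills a temporary M x K buffer with other[col][k] * grad_out[m][k] for every position whose saved arg index is valid, and a serial pass then scatter-adds that buffer into the gradient of the sparse values. An arg index equal to nnz acts as a sentinel for "no element was selected" and contributes nothing. The following is a minimal standalone sketch of that logic, not the ATen implementation: the function name spmm_arg_backward_input_sketch, the use of std::vector, the float/int64_t types, and the fusing of both phases into one serial loop are all assumptions made for illustration. It also starts the output from zeros, whereas the original accumulates into the existing values of grad_self.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper for illustration only; it is not part of ATen.
// All matrices are dense and row-major:
//   grad_out    : M x K, incoming gradient of the reduced output
//   arg_out     : M x K, index into the nonzeros chosen by max/min,
//                 or nnz as a sentinel meaning "nothing selected"
//   col_indices : nnz entries, column index of each nonzero
//   other       : N x K, the dense operand of the forward spmm
// Returns the gradient w.r.t. the nnz sparse values, starting from zero.
std::vector<float> spmm_arg_backward_input_sketch(
    const std::vector<float>& grad_out,
    const std::vector<int64_t>& arg_out,
    const std::vector<int64_t>& col_indices,
    const std::vector<float>& other,
    int64_t M,
    int64_t K) {
  const int64_t nnz = static_cast<int64_t>(col_indices.size());
  assert(static_cast<int64_t>(grad_out.size()) == M * K);
  assert(static_cast<int64_t>(arg_out.size()) == M * K);

  std::vector<float> grad_values(nnz, 0.0f);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t k = 0; k < K; ++k) {
      const int64_t ind = arg_out[m * K + k];
      if (ind == nnz) {
        continue;  // sentinel: this output position has zero gradient
      }
      // Weight of the dense operand at the column of the selected nonzero.
      const int64_t col = col_indices[ind];
      // Several (m, k) positions may select the same nonzero, so accumulate.
      grad_values[ind] += other[col * K + k] * grad_out[m * K + k];
    }
  }
  return grad_values;
}
```

Because different output positions can point at the same nonzero, the final accumulation has a write conflict; this is presumably why the original keeps the scatter step serial and leaves a note about parallelizing it with atomics.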