
spmm_reduce_backward_other_arg_kernel_impl Function Template - PyTorch Architecture

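Architecture documentation for the spmm_reduce_backward_other_arg_kernel_impl function template in SpmmReduceKernel.cpp from the PyTorch codebase. This CPU kernel accumulates into grad_other the gradient of a sparse-dense matrix product with respect to the dense operand. It relies on arg_out, which records, for each output element, the index of the nonzero selected by the reduction, or nnz when no nonzero was selected.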
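Reading the loops in the source, the kernel can be summarized by the formulas below. The notation is introduced here only for illustration: a_{mk} is arg_out[m, k], N is nnz, v_j is values[j], and c_j is col_indices[j]. The forward relation on the first line is an assumption inferred from the code (each output entry is taken to depend on other only through the single nonzero recorded in arg_out, as an amax/amin-style reduction would); this page does not state it.

$$
\mathrm{out}_{mk} = v_{a_{mk}} \, \mathrm{other}_{c_{a_{mk}},\,k} \qquad (a_{mk} \neq N)
$$

$$
g_{mk} = \begin{cases} v_{a_{mk}} \, \dfrac{\partial L}{\partial \mathrm{out}_{mk}} & a_{mk} \neq N \\ 0 & a_{mk} = N \end{cases}
\qquad
\frac{\partial L}{\partial \mathrm{other}_{ck}} = \sum_{m \,:\; a_{mk} \neq N,\; c_{a_{mk}} = c} g_{mk}
$$

The at::parallel_for loop fills grad with g_{mk} (grad_out holds the upstream gradient with respect to out), and the serial scatter_add loop performs the sum by adding each g_{mk} into grad_other at row c_{a_{mk}} and column k. Because that loop uses +=, the result is accumulated on top of whatever grad_other already holds.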


Source Code

aten/src/ATen/native/cpu/SpmmReduceKernel.cpp lines 375–428

template <typename scalar_t, typename index_t>
void spmm_reduce_backward_other_arg_kernel_impl(
    const Tensor& grad_other,
    const Tensor& grad_out_,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& arg_out_) {

  int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto arg_out = arg_out_.contiguous();

  scalar_t* grad_other_data = grad_other.data_ptr<scalar_t>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  auto values_data = values.accessor<const scalar_t, 1>();
  const index_t* arg_out_data = arg_out.const_data_ptr<index_t>();

  int64_t M = grad_out.size(0);
  int64_t K = grad_out.size(1);
  auto grad = at::empty({M, K}, grad_out.options());
  scalar_t* grad_data = grad.mutable_data_ptr<scalar_t>();

  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      const scalar_t* grad_out_ptr = grad_out_data + m * K;
      scalar_t* grad_ptr = grad_data + m * K;
      const index_t* arg_out_ptr = arg_out_data + m * K;

      for (const auto k : c10::irange(K)) {
        if (arg_out_ptr[k] == index_t(nnz)) {
          grad_ptr[k] = scalar_t(0);
        } else {
          grad_ptr[k] = values_data[arg_out_ptr[k]] * grad_out_ptr[k];
        }
      }
    }
  });

  // scatter_add, consider to parallel this with atomic
  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      index_t ind = arg_out_data[m * K + k];
      if (ind != index_t(nnz)) {
        index_t col = col_data[ind];
        grad_other_data[col * K + k] += grad_data[m * K + k];
      }
    }
  }
}
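For experimenting outside ATen, the same two phases can be sketched over plain std::vector storage. Everything in this sketch is hypothetical: the names backward_other_arg_serial, backward_other_arg_by_column and run_example, the row-major vector layout, and std::thread standing in for at::parallel_for. The serial function mirrors the kernel above. The by-column function is one possible answer to the source comment about parallelizing the scatter; it is an assumption of this sketch, not PyTorch's approach, and it splits the work by output column k instead of using atomics.

```cpp
#include <cassert>
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical std::vector mirror of the kernel (no ATen types).
// Row-major layouts: grad_out and arg_out are M x K, grad_other is
// num_cols x K, values and col_indices have length nnz.
// arg_out[m * K + k] == nnz is the sentinel for "no nonzero selected".

// Serial version: the same two phases as the original kernel.
template <typename scalar_t, typename index_t>
void backward_other_arg_serial(std::vector<scalar_t>& grad_other,
                               const std::vector<scalar_t>& grad_out,
                               const std::vector<index_t>& col_indices,
                               const std::vector<scalar_t>& values,
                               const std::vector<index_t>& arg_out,
                               int64_t M, int64_t K) {
  const int64_t nnz = static_cast<int64_t>(values.size());
  if (nnz == 0) return;
  assert(static_cast<int64_t>(arg_out.size()) == M * K);

  // Phase 1: per-entry gradient, zero where the sentinel was recorded.
  std::vector<scalar_t> grad(M * K);
  for (int64_t i = 0; i < M * K; ++i) {
    const index_t ind = arg_out[i];
    grad[i] = (ind == index_t(nnz)) ? scalar_t(0) : values[ind] * grad_out[i];
  }

  // Phase 2: scatter_add into row col_indices[ind], column i % K, of grad_other.
  for (int64_t i = 0; i < M * K; ++i) {
    const index_t ind = arg_out[i];
    if (ind != index_t(nnz)) {
      grad_other[col_indices[ind] * K + i % K] += grad[i];
    }
  }
}

// Alternative (an assumption, not PyTorch's approach): split the scatter by
// output column k. grad_other[c * K + k] is only ever updated by entries with
// that same k, so threads owning disjoint k values never write the same
// element and no atomics are needed. Phases 1 and 2 are fused here.
template <typename scalar_t, typename index_t>
void backward_other_arg_by_column(std::vector<scalar_t>& grad_other,
                                  const std::vector<scalar_t>& grad_out,
                                  const std::vector<index_t>& col_indices,
                                  const std::vector<scalar_t>& values,
                                  const std::vector<index_t>& arg_out,
                                  int64_t M, int64_t K, int num_threads) {
  const int64_t nnz = static_cast<int64_t>(values.size());
  if (nnz == 0) return;
  assert(num_threads >= 1);
  std::vector<std::thread> workers;
  for (int t = 0; t < num_threads; ++t) {
    workers.emplace_back([&, t] {
      // Thread t owns columns k = t, t + num_threads, t + 2 * num_threads, ...
      for (int64_t k = t; k < K; k += num_threads) {
        for (int64_t m = 0; m < M; ++m) {
          const index_t ind = arg_out[m * K + k];
          if (ind != index_t(nnz)) {
            grad_other[col_indices[ind] * K + k] += values[ind] * grad_out[m * K + k];
          }
        }
      }
    });
  }
  for (auto& w : workers) w.join();
}

// Usage: a 4 x 3 sparse matrix (row 1 empty) times a 3 x 2 dense `other`.
// values/col_indices list nonzeros in row order; arg_out is chosen by hand.
std::vector<double> run_example(bool by_column) {
  const std::vector<double> values = {2, 3, 4, 5};
  const std::vector<int64_t> col_indices = {0, 2, 1, 1};
  const std::vector<int64_t> arg_out = {1, 0, 4, 4, 2, 2, 3, 3};  // 4 == nnz
  const std::vector<double> grad_out = {1, 2, 5, 6, 3, 4, 1, 1};
  std::vector<double> grad_other(3 * 2, 0.0);
  if (by_column) {
    backward_other_arg_by_column(grad_other, grad_out, col_indices, values, arg_out, 4, 2, 2);
  } else {
    backward_other_arg_serial(grad_other, grad_out, col_indices, values, arg_out, 4, 2);
  }
  return grad_other;  // {0, 4, 17, 21, 3, 0}
}
```

run_example returns {0, 4, 17, 21, 3, 0} with either version. The by-column split adds into each grad_other element in the same order (increasing m) as the serial loop, so its results stay deterministic. The trade-offs are that parallelism is capped at K and that arg_out and grad_out are read with stride K instead of row by row. Atomic adds, as the source comment suggests, would keep the row-wise traversal but make the floating-point summation order depend on thread scheduling.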
