spmm_reduce_backward_input_arg_kernel_impl Function Template — pytorch Architecture
Architecture documentation for the spmm_reduce_backward_input_arg_kernel_impl function template in SpmmReduceKernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/SpmmReduceKernel.cpp lines 291–345
// Accumulates into the values of `grad_self` the gradient contributions of an
// arg-style (max/min) spmm reduction.
//
// Works in two phases:
//   1. In parallel over the rows of `grad_out`, stage one contribution per
//      output element (m, k): other[col, k] * grad_out[m, k], where `col` is
//      looked up through `col_indices` at the position stored in `arg_out`.
//      Elements whose `arg_out` entry equals nnz get a zero contribution.
//   2. Serially add every staged contribution into the `grad_self` value slot
//      named by `arg_out`, skipping entries equal to nnz.
//
// Returns immediately when `grad_self` has no stored values (nnz == 0).
//
// NOTE(review): presumably an `arg_out` entry equal to nnz marks an output
// element for which no stored value was selected -- confirm with the forward
// kernel that produces `arg_out`.
template <typename scalar_t, typename index_t>
void spmm_reduce_backward_input_arg_kernel_impl(
    const Tensor& grad_self,
    const Tensor& grad_out_,
    const Tensor& col_indices,
    const Tensor& other_,
    const Tensor& arg_out_) {
  const int64_t nnz = grad_self._nnz();
  if (nnz == 0) {
    return;
  }

  // Contiguous copies so that flat offsets (row * K + k) are valid below.
  auto grad_out = grad_out_.contiguous();
  auto other = other_.contiguous();
  auto arg_out = arg_out_.contiguous();

  auto grad_values = grad_self.values();
  auto grad_values_data = grad_values.accessor<scalar_t, 1>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  const scalar_t* other_data = other.const_data_ptr<scalar_t>();
  index_t* arg_idx = arg_out.data_ptr<index_t>();

  const int64_t M = grad_out.size(0);
  const int64_t K = grad_out.size(1);

  // Scratch buffer holding one staged contribution per output element.
  auto staged_grad = at::empty({M, K}, grad_out.options());
  scalar_t* staged = staged_grad.mutable_data_ptr<scalar_t>();

  // `arg_out` entries equal to this value are skipped in both phases.
  const index_t no_arg = index_t(nnz);

  // Phase 1: every (row, k) writes only its own staging slot, so rows can be
  // processed independently.
  at::parallel_for(0, M, 1, [&](int64_t row_begin, int64_t row_end) {
    for (const auto row : c10::irange(row_begin, row_end)) {
      const int64_t row_offset = row * K;
      for (const auto k : c10::irange(K)) {
        const int64_t flat = row_offset + k;
        const index_t picked = arg_idx[flat];
        if (picked != no_arg) {
          // collect weight at max/min indices
          const index_t col = col_data[picked];
          staged[flat] = other_data[col * K + k] * grad_out_data[flat];
        } else {
          staged[flat] = scalar_t(0);
        }
      }
    }
  });

  // Phase 2: scatter_add into the sparse values. Kept serial; the original
  // note suggests parallelizing this would require atomic adds.
  const int64_t total = M * K;
  for (const auto flat : c10::irange(total)) {
    const index_t target = arg_idx[flat];
    if (target == no_arg) {
      continue;
    }
    grad_values_data[target] += staged[flat];
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free