spmm_reduce_backward_other_arg_kernel_impl Function — pytorch Architecture
Architecture documentation for the spmm_reduce_backward_other_arg_kernel_impl template function in SpmmReduceKernel.cpp from the pytorch codebase. It computes the gradient with respect to the dense `other` operand of a sparse-dense matmul with an arg-tracking reduction (max/min), using the recorded argmax/argmin indices to route each upstream gradient entry back to the nonzero that produced it.
Entity Profile
Source Code
aten/src/ATen/native/cpu/SpmmReduceKernel.cpp lines 375–428
template <typename scalar_t, typename index_t>
void spmm_reduce_backward_other_arg_kernel_impl(
    const Tensor& grad_other,
    const Tensor& grad_out_,
    const Tensor& col_indices,
    const Tensor& values,
    const Tensor& arg_out_) {
  int64_t nnz = values.numel();
  if (nnz == 0) {
    return;
  }

  auto grad_out = grad_out_.contiguous();
  auto arg_out = arg_out_.contiguous();

  scalar_t* grad_other_data = grad_other.data_ptr<scalar_t>();
  const scalar_t* grad_out_data = grad_out.const_data_ptr<scalar_t>();
  auto col_data = col_indices.accessor<const index_t, 1>();
  auto values_data = values.accessor<const scalar_t, 1>();
  const index_t* arg_out_data = arg_out.const_data_ptr<index_t>();

  int64_t M = grad_out.size(0);
  int64_t K = grad_out.size(1);
  auto grad = at::empty({M, K}, grad_out.options());
  scalar_t* grad_data = grad.mutable_data_ptr<scalar_t>();

  // Phase 1 (parallel over rows): scale each upstream gradient entry by the
  // sparse value that won the reduction; an index of nnz marks an empty row.
  at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
    for (const auto m : c10::irange(begin, end)) {
      const scalar_t* grad_out_ptr = grad_out_data + m * K;
      scalar_t* grad_ptr = grad_data + m * K;
      const index_t* arg_out_ptr = arg_out_data + m * K;

      for (const auto k : c10::irange(K)) {
        if (arg_out_ptr[k] == index_t(nnz)) {
          grad_ptr[k] = scalar_t(0);
        } else {
          grad_ptr[k] = values_data[arg_out_ptr[k]] * grad_out_ptr[k];
        }
      }
    }
  });

  // Phase 2: scatter_add into grad_other; consider parallelizing this with atomics
  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      index_t ind = arg_out_data[m * K + k];
      if (ind != index_t(nnz)) {
        index_t col = col_data[ind];
        grad_other_data[col * K + k] += grad_data[m * K + k];
      }
    }
  }
}
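The kernel's two phases can be sketched on plain arrays, independent of ATen. This is a minimal, hypothetical simplification assuming row-major M×K buffers stored in std::vector, float elements, and an arg_out produced by a prior max/min reduction; the function name, the parameter N (the number of rows of `other`), and the element type are illustrative and not taken from the source.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical standalone sketch of the backward pass for the dense `other`
// operand. An entry arg_out[m*K + k] == nnz means no nonzero contributed to
// output position (m, k), so no gradient flows through it.
std::vector<float> spmm_reduce_backward_other_arg(
    const std::vector<float>& grad_out,      // M x K upstream gradient
    const std::vector<int64_t>& col_indices, // column of each nonzero, size nnz
    const std::vector<float>& values,        // value of each nonzero, size nnz
    const std::vector<int64_t>& arg_out,     // M x K winning nonzero index
    int64_t M, int64_t K, int64_t N) {       // N = number of rows of `other`
  const int64_t nnz = static_cast<int64_t>(values.size());
  std::vector<float> grad_other(N * K, 0.0f);

  // Phase 1 (parallelized over rows in the real kernel): multiply each
  // upstream gradient entry by the sparse value that won the reduction.
  std::vector<float> grad(M * K, 0.0f);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t k = 0; k < K; ++k) {
      int64_t e = arg_out[m * K + k];
      grad[m * K + k] = (e == nnz) ? 0.0f : values[e] * grad_out[m * K + k];
    }
  }

  // Phase 2 (sequential scatter_add): route each scaled entry to the row of
  // `other` that produced the winning product, i.e. row col_indices[e].
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t k = 0; k < K; ++k) {
      int64_t e = arg_out[m * K + k];
      if (e != nnz) {
        grad_other[col_indices[e] * K + k] += grad[m * K + k];
      }
    }
  }
  return grad_other;
}
```

The scatter in phase 2 is kept sequential because several output rows `m` can select nonzeros that point at the same row of `other`, so a naive parallelization would race on `grad_other`; the source comment notes atomics as one way to parallelize it.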