reduce (ReductionType template parameter) — PyTorch Architecture
Architecture documentation for the `reduce`-parameterized scatter-reduce kernel in ScatterGatherKernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/ScatterGatherKernel.cpp lines 692–815
// Scatter-reduce along dim 0 with an "expanded" index: each source row
// src[i] is folded into destination row self[index_data[i]] using the
// reduction `reduce` (dispatched through the _init/update/write helpers
// defined elsewhere in this file).
//
// Strategy: radix-sort source positions by destination row, build a
// CSR-style (row, offset) table of the distinct destination rows, then let
// each thread own whole destination rows — no two threads ever write the
// same row, so no atomics are needed.
//
// Assumptions not checked here (NOTE(review) — confirm at the dispatch
// site): self/src are laid out so row r starts at flat offset r * K, and
// index is presumably an expanded view along dim 1 (stride 0), so
// index_data[i] is the single destination row for source row i.
template <typename scalar_t, ReductionType reduce>
void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index, const Tensor& src, bool include_self) {
const int64_t* index_data = index.const_data_ptr<int64_t>();
scalar_t* self_data = self.data_ptr<scalar_t>();
const scalar_t* src_data = src.const_data_ptr<scalar_t>();
// M: rows in self (the valid index range); nnz: number of source rows;
// K: elements per row.
const int64_t M = ensure_nonempty_size(self, 0);
const int64_t nnz = ensure_nonempty_size(index, 0);
const int64_t K = index.numel() / nnz;
const int64_t index_upper_bound = M;
// Build (destination row, source position) pairs for sorting. keys/values
// are the sort inputs; *_tmp are scratch buffers radix_sort_parallel needs.
auto keys = std::make_unique<int64_t[]>(nnz);
auto values = std::make_unique<int64_t[]>(nnz);
auto keys_tmp = std::make_unique<int64_t[]>(nnz);
auto values_tmp = std::make_unique<int64_t[]>(nnz);
// Populate the pairs and bounds-check every index value up front.
at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
// NB: this local deliberately shadows the `index` Tensor parameter.
int64_t index = index_data[i];
TORCH_CHECK(index >= 0 && index < index_upper_bound,
"index ", index,
" is out of bounds for dimension ", 0,
" with size ", index_upper_bound);
keys[i] = index;
values[i] = i;
}
});
// After the sort, sorted_col_index_keys holds destination rows in
// ascending order and sorted_col_index_values the matching original source
// positions. The returned pointers alias either the input or the tmp
// buffers, whichever holds the final sorting pass.
int64_t* sorted_col_index_keys = nullptr;
int64_t* sorted_col_index_values = nullptr;
std::tie(sorted_col_index_keys, sorted_col_index_values) = fbgemm::radix_sort_parallel(
keys.get(),
values.get(),
keys_tmp.get(),
values_tmp.get(),
nnz,
M);
// Count distinct destination rows: each thread counts key transitions in
// its chunk, then a serial prefix sum turns per-thread counts into
// cumulative write offsets for the pass below.
// NOTE(review): this and the second pass rely on at::parallel_for handing
// each thread a contiguous range in thread-id order — confirm against
// ATen's intra-op parallelism contract.
int num_threads = at::get_num_threads();
std::vector<int64_t> num_uniq(num_threads, 0);
at::parallel_for(1, nnz, 1, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
for(const auto i : c10::irange(begin, end)) {
if (sorted_col_index_keys[i] != sorted_col_index_keys[i - 1]) {
num_uniq[tid]++;
}
}
});
// Position 0 has no predecessor, so its run was never counted above.
num_uniq[0]++;
for (const auto n : c10::irange(1, num_threads)) {
num_uniq[n] += num_uniq[n - 1];
}
// in case some rows are not written into, num_nonzero_rows will be smaller than M
int64_t num_nonzero_rows = num_uniq[num_threads - 1];
// CSR-style table: row_index[m] is the m-th distinct destination row; its
// entries occupy [row_index_offset[m], row_index_offset[m + 1]) in the
// sorted arrays.
auto row_index_tmp = std::make_unique<int64_t[]>(num_nonzero_rows);
auto row_index_offset_tmp = std::make_unique<int64_t[]>(num_nonzero_rows + 1);
int64_t* row_index = row_index_tmp.get();
int64_t* row_index_offset = row_index_offset_tmp.get();
row_index[0] = sorted_col_index_keys[0];
row_index_offset[0] = 0;
row_index_offset[num_nonzero_rows] = nnz;
// Each thread writes the run boundaries found in its chunk, starting at
// the slot given by the prefix sum above (slot 0 was prefilled).
at::parallel_for(1, nnz, 1, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
int64_t* t_index = row_index + ((tid == 0) ? 1 : num_uniq[tid - 1]);
int64_t* t_index_offset = row_index_offset + ((tid == 0) ? 1 : num_uniq[tid - 1]);
for (const auto i : c10::irange(begin, end)) {
if (sorted_col_index_keys[i] != sorted_col_index_keys[i - 1]) {
*t_index = sorted_col_index_keys[i];
*t_index_offset = i;
t_index++;
t_index_offset++;
}
}
});
// Reduced-precision floats (bf16/fp16) accumulate in the wider opmath_t;
// allocate one K-wide accumulation row per thread.
using opmath_t = at::opmath_type<scalar_t>;
Tensor buffer;
opmath_t* buffer_data = nullptr;
static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
if constexpr (need_acc) {
auto acc_type = at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true);
buffer = at::zeros({num_threads, K}, self.options().dtype(acc_type));
buffer_data = buffer.data_ptr<opmath_t>();
}
// Main reduction: threads process whole destination rows, so a given row
// of self is only ever touched by one thread.
// TODO: do blocking on col dimension to reduce WR bandwidth
at::parallel_for(0, num_nonzero_rows, 1, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
TORCH_CHECK(tid < num_threads,
"expect thread id smaller than ", num_threads, ", got thread id ", tid);
opmath_t* buffer_ptr = nullptr;
for (const auto m : c10::irange(begin, end)) {
int64_t row = row_index[m];
int64_t off_start = row_index_offset[m];
int64_t off_end = row_index_offset[m + 1];
scalar_t* self_ptr = self_data + row * K;
if constexpr (need_acc) {
buffer_ptr = buffer_data + tid * K;
} else {
// opmath_t == scalar_t in this branch, so the cast is a no-op:
// reduce directly into self's row.
buffer_ptr = reinterpret_cast<opmath_t*>(self_ptr);
}
// step 1: reinit rows in `self` if needed
_init<scalar_t, reduce>(self_ptr, buffer_ptr, K, include_self);
// step 2: reduce
for (const auto n : c10::irange(off_start, off_end)) {
int64_t col = sorted_col_index_values[n];
update<scalar_t, reduce>(buffer_ptr, src_data + col * K, K);
}
// Copy the high-precision accumulator back into self's row.
if constexpr (need_acc) {
vec::convert(buffer_ptr, self_ptr, K);
}
// step 3: finalize
// count = number of values reduced into this row (used e.g. by mean).
int64_t count = include_self ? 1 : 0;
count += off_end - off_start;
write<scalar_t, reduce>(self_ptr, count, K);
}
});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free