binary_kernel_reduce Function Template — PyTorch Architecture
Architecture documentation for the binary_kernel_reduce function template in Reduce.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/Reduce.h lines 184–248
/// Reduce every output element of `iter` by folding its single input operand
/// through the reduction object `ops`.
///
/// `iter.foreach_reduced_elt` supplies one `sub_iter` per output element.
/// For each one:
///   1. the input values are accumulated with `ops.reduce(acc, value, idx)`,
///      starting from `init`;
///   2. the result is finalized with `ops.project(acc)`;
///   3. it is written to the outputs by `set_results`.
/// Large reductions are split across threads with `at::parallel_for`, and the
/// per-thread partial accumulators are merged with `ops.combine`.
///
/// Requirements on `ops_t` (enforced by the static_asserts below or by use):
///   - `reduce(acc_t, data_t, int64_t) -> acc_t`
///   - `combine(acc_t, acc_t) -> acc_t`
///   - `project(acc_t)`: its result is handed to `set_results`
///   - `translate_idx(acc_t, int64_t) -> acc_t`
/// `init_t` must be exactly the accumulator type `acc_t`.
///
/// @param iter  Iterator with `noutputs()` outputs and exactly one input
///              operand (checked at runtime with AT_ASSERT).
/// @param ops   Reduction operations object; lambdas capture it by reference.
/// @param init  Initial accumulator value. NOTE: in the parallel path every
///              thread slot starts from `init`, and all slots are combined
///              into a total that also starts from `init`. This includes slots
///              of threads that received no work. `init` is therefore assumed
///              to act as an identity for `ops.combine`.
template <typename ops_t, typename init_t>
void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
  using rf_t = decltype(&ops_t::reduce);
  using cf_t = decltype(&ops_t::combine);
  using pf_t = decltype(&ops_t::project);
  using r_traits = binary_function_traits<rf_t>;
  using c_traits = binary_function_traits<cf_t>;
  using p_traits = unary_function_traits<pf_t>;
  // The accumulator type is whatever `project` consumes; the element type
  // loaded from the input is the second parameter of `reduce`.
  using acc_t = typename p_traits::arg1_t;
  using data_t = typename r_traits::arg2_t;
  static_assert(
    all_same<
      acc_t,
      init_t,
      typename r_traits::arg1_t,
      typename r_traits::result_type,
      typename c_traits::arg1_t,
      typename c_traits::arg2_t,
      typename c_traits::result_type>::value,
    "all accumulate types must match");
  static_assert(
    std::is_default_constructible_v<acc_t>,
    "the accumulate type must be default-constructible"
  );
  const int num_outputs = iter.noutputs();
  iter.foreach_reduced_elt([&ops, &init, num_outputs](TensorIteratorBase &sub_iter) {
    // Folds the input elements in the linear range [begin, end) of `sub_iter`
    // into `acc` and returns the updated accumulator.
    auto reduction_body = [&ops, &sub_iter, num_outputs](acc_t acc, int64_t begin, int64_t end) -> acc_t {
      int ntensors = sub_iter.ntensors();
      sub_iter.serial_for_each([&acc, &ops, num_outputs, ntensors, begin](char** data, const int64_t* strides, int64_t size) {
        // Exactly one input operand is supported; it is the last operand.
        AT_ASSERT(ntensors - num_outputs == 1);
        char *in = data[ntensors - 1];
        int64_t stride = strides[ntensors - 1];
        for (const auto i : c10::irange(size)) {
          // NOTE(review): `begin + i` is the element's linear position within
          // `sub_iter` only if this 1-D callback is invoked once for the whole
          // range. That presumably holds for single-dimension reductions.
          // TODO confirm for multi-dim sub-iterators.
          acc = ops.reduce(acc, c10::load<data_t>(in), begin + i);
          in += stride;
        }
      }, {begin, end});
      // Let `ops` adjust any index held in `acc` using this sub-iterator's
      // offset into the original view. This presumably turns local indices
      // into absolute ones; TODO confirm against the ops implementations.
      return ops.translate_idx(acc, sub_iter.view_offsets()[0]);
    };
    acc_t total_acc = init;
    auto numel = sub_iter.numel();
    // Stay serial for small inputs or single-threaded configurations. Also
    // stay serial when already inside a parallel region, so no nested
    // parallel_for is launched from there.
    if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
        at::in_parallel_region()) {
      total_acc = reduction_body(total_acc, 0, numel);
    } else {
      int max_threads = at::get_num_threads();
      AT_ASSERT(max_threads > 0);
      static_assert(
        !std::is_same_v<acc_t, bool>,
        "Concurrently modifying different references into std::vector<bool> is UB."
      );
      // One partial accumulator per thread slot, each seeded with `init`.
      std::vector<acc_t> buffer(static_cast<size_t>(max_threads), init);
      at::parallel_for(0, numel, internal::GRAIN_SIZE,
        [&](int64_t begin, int64_t end) {
          // Assumes at::get_thread_num() is unique among concurrently running
          // workers and < max_threads, so each slot has a single writer.
          auto& acc = buffer[at::get_thread_num()];
          acc = reduction_body(acc, begin, end);
        }
      );
      // Merge the per-thread partials in slot order.
      for (const auto& partial : buffer) {
        total_acc = ops.combine(total_acc, partial);
      }
    }
    set_results<r_traits>(ops.project(total_acc), sub_iter, num_outputs);
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free