ignore_nan Template Parameter — pytorch Architecture
Architecture documentation for the ignore_nan template parameter of cascade_sum in SumKernel.cpp from the pytorch codebase. ignore_nan is a compile-time boolean flag that selects whether the cascading CPU sum kernel skips NaN values during accumulation (the NaN-ignoring path used by reductions such as nansum) or propagates them (the plain sum path).
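For context, the user-visible effect of the flag is the difference between a plain sum, which propagates NaN, and a NaN-skipping sum. A minimal sketch, assuming the standard libtorch torch::sum and torch::nansum entry points (these calls are not part of the excerpt below):

#include <torch/torch.h>
#include <cmath>
#include <iostream>

int main() {
  // One NaN element in an otherwise ordinary float tensor.
  auto t = torch::tensor({1.0f, 2.0f, std::nanf(""), 4.0f});

  // Plain sum (the ignore_nan == false path on CPU): NaN propagates.
  std::cout << torch::sum(t) << "\n";     // nan

  // NaN-skipping sum (the ignore_nan == true path on CPU): NaN counts as 0.
  std::cout << torch::nansum(t) << "\n";  // 7
}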
Entity Profile
Source Code
aten/src/ATen/native/cpu/SumKernel.cpp lines 536–612
template <bool ignore_nan, typename scalar_t>
void cascade_sum(TensorIterator &iter) {
  iter.output_base().fill_(scalar_t(0));
  iter.parallel_reduce(
    [&](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
      // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
      int64_t in_strides[] = { strides[1], strides[3] };
      // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
      int64_t out_strides[] = { strides[0], strides[2] };

      // Move reduction to be the 1st dim
      if (out_strides[0] != 0 && out_strides[1] == 0) {
        std::swap(in_strides[0], in_strides[1]);
        std::swap(out_strides[0], out_strides[1]);
        std::swap(size0, size1);
      }

      // Special case? - not a true reduction
      if (out_strides[0] != 0 && out_strides[1] != 0) {
        // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
        int64_t outer_strides[] = { strides[2], strides[3] };
        UNARY_OUTER_LOOP(data, outer_strides, size1, [&] {
          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
          char* ptrs[3] = { data[0], data[0], data[1] };
          // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
          int64_t inner_strides[3] = { strides[0], strides[0], strides[1] };
          if constexpr (ignore_nan) {
            basic_loop(ptrs, inner_strides, 0, size0, [](scalar_t a, scalar_t b) {
              auto a_notnan = at::_isnan(a) ? scalar_t(0) : a;
              auto b_notnan = at::_isnan(b) ? scalar_t(0) : b;
              return a_notnan + b_notnan;
            });
          } else {
            basic_loop(ptrs, inner_strides, 0, size0,
                       [](scalar_t a, scalar_t b) { return a + b; });
          }
        });
        return;
      }

      const int64_t out_stride = out_strides[1];
      TORCH_INTERNAL_ASSERT(out_strides[0] == 0);

      using vec_t = Vectorized<scalar_t>;
      using acc_t = at::acc_type<scalar_t, true>;
      using vacc_t = Vectorized<acc_t>;
      using ScalarLoadPolicy = std::conditional_t<
          ignore_nan,
          NanSumCastLoadPolicy<scalar_t, acc_t>,
          CastLoadPolicy<scalar_t, acc_t>>;
      using StorePolicy = CastStoreAccumulate<scalar_t, acc_t>;

      if (in_strides[0] == sizeof(scalar_t) && size0 >= vec_t::size()) {
        // Contiguous inner reduction
        using VecLoadPolicy = std::conditional_t<
            ignore_nan,
            InnerNanSumCastLoadPolicy<vec_t, vacc_t>,
            InnerSumCastLoadPolicy<vec_t, vacc_t>>;
        vectorized_inner_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
            data, in_strides[1], out_stride, size0, size1);
      } else if (in_strides[1] == sizeof(scalar_t) && size1 >= vec_t::size()) {
        // Contiguous outer reduction
        using VecLoadPolicy = std::conditional_t<
            ignore_nan,
            OuterNanSumCastLoadPolicy<vec_t, vacc_t>,
            OuterSumCastLoadPolicy<vec_t, vacc_t>>;
        vectorized_outer_sum<acc_t, VecLoadPolicy, ScalarLoadPolicy, StorePolicy>(
            data, in_strides[0], out_stride, size0, size1);
      } else if (in_strides[0] < in_strides[1]) {
        scalar_inner_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
            data, in_strides, out_stride, size0, size1);
      } else {
        scalar_outer_sum<acc_t, ScalarLoadPolicy, StorePolicy>(
            data, in_strides, out_stride, size0, size1);
      }
    });
}
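To make the accumulation rule concrete outside of ATen, here is a self-contained sketch of the combine step that ignore_nan selects: when the flag is true, each operand is replaced by 0 if it is NaN before being added, mirroring the lambda passed to basic_loop above. The combine helper and main function are illustrative only and are not part of the kernel.

#include <cmath>
#include <iostream>
#include <vector>

// Sketch of the combine step selected by ignore_nan. With the flag set,
// NaN operands are treated as 0; otherwise NaN propagates into the sum.
template <bool ignore_nan, typename scalar_t>
scalar_t combine(scalar_t a, scalar_t b) {
  if constexpr (ignore_nan) {
    scalar_t a_notnan = std::isnan(a) ? scalar_t(0) : a;
    scalar_t b_notnan = std::isnan(b) ? scalar_t(0) : b;
    return a_notnan + b_notnan;
  } else {
    return a + b;
  }
}

int main() {
  std::vector<float> v = {1.0f, 2.0f, std::nanf(""), 4.0f};
  float plain = 0.0f;
  float skipping = 0.0f;
  for (float x : v) {
    plain = combine<false, float>(plain, x);      // propagates NaN
    skipping = combine<true, float>(skipping, x); // treats NaN as 0
  }
  std::cout << plain << " " << skipping << "\n";  // prints: nan 7
}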