arg_t Type Alias — PyTorch Architecture
Architecture documentation for arg_t, the (value, index) pair type alias declared inside the MinMaxReductionOps struct in SharedReduceOps.h of the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/SharedReduceOps.h lines 461–489
// Reduction policy for min/max-style reductions that carry the winning
// element's value together with its index.
//
// comp_t must be default-constructible and callable as
//   comp_t{}(value_a, value_b, index_a, index_b)
// Its result is treated as "keep the first candidate (value_a, index_a)";
// otherwise the second candidate is taken.
template <typename comp_t>
struct MinMaxReductionOps {
  // Element value type: the type of comp_t's first argument.
  using scalar_t = typename binary_function_traits<comp_t>::arg1_t;
  // Index type carried alongside each value.
  using index_t = int64_t;
  // Accumulator: (value, index) of the current winner.
  using arg_t = detail::pair<scalar_t, index_t>;

  // Final projection step: the accumulator is handed back unchanged.
  static C10_DEVICE arg_t project(arg_t acc) {
    return acc;
  }

  // Folds one element (candidate, candidate_idx) into the running winner.
  // The running winner survives whenever the comparator prefers it.
  static C10_DEVICE arg_t reduce(arg_t acc, scalar_t candidate, int64_t candidate_idx) {
    if (comp_t{}(acc.first, candidate, acc.second, candidate_idx)) {
      return acc;
    }
    return arg_t(candidate, candidate_idx);
  }

  // Merges two partial results; lhs is returned when the comparator prefers
  // it, rhs otherwise.
  static C10_DEVICE arg_t combine(arg_t lhs, arg_t rhs) {
    if (comp_t{}(lhs.first, rhs.first, lhs.second, rhs.second)) {
      return lhs;
    }
    return rhs;
  }

  // Shifts a partial result's index by base_idx; the value is left as is.
  static C10_DEVICE arg_t translate_idx(arg_t acc, int64_t base_idx) {
    return arg_t(acc.first, acc.second + base_idx);
  }

#if defined(__CUDACC__) || defined(__HIPCC__)
  // Device-only: moves both halves of the pair down by `offset` lanes using
  // WARP_SHFL_DOWN, one shuffle for the value and one for the index.
  static C10_DEVICE arg_t warp_shfl_down(arg_t acc, int offset) {
    const auto shuffled_value = WARP_SHFL_DOWN(acc.first, offset);
    const auto shuffled_index = WARP_SHFL_DOWN(acc.second, offset);
    return arg_t(shuffled_value, shuffled_index);
  }
#endif
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free