cpu_take_put_kernel Function — PyTorch Architecture
Architecture documentation for the cpu_take_put_kernel function template in IndexKernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/IndexKernel.cpp lines 63–113
template <typename scalar_t, typename func_t>
void cpu_take_put_kernel(
    TensorIterator& iter,
    const TensorBase& indexed,
    bool is_indexed_data_mutated,
    const func_t& f,
    bool serial_execution=false) {
  // This kernel follows the same strategy as `cpu_index_kernel`.
  // Even though the indexed tensor is const, we modify it through its data pointer.
  // This is a bit dirty, but the alternative would be to unnecessarily add a
  // zero-strided tensor to `iter`, which would not be much better.

  // When launching the parallel version, use a relatively small grain size
  // (less than internal::GRAIN_SIZE) so that the available threads get a more
  // balanced workload and better cache locality.
  // The grain size here was chosen via op benchmarks to overcome thread-launch overhead.
  // Perhaps tweak this number for `put_`? This number was tweaked for `index_put`.
  constexpr int parallel_grain_size = 3000;
  const bool is_contiguous = indexed.is_contiguous();
  const auto numel = indexed.numel();
  const auto offset_indexed = IndexToOffset(indexed);

  auto* indexed_data = is_indexed_data_mutated ?
      indexed.data_ptr<scalar_t>()
      : const_cast<scalar_t*>(indexed.const_data_ptr<scalar_t>());
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    auto* iterated_data_bytes = data[0];
    auto* index_data_bytes = data[1];
    for ([[maybe_unused]] const auto elem : c10::irange(n)) {
      auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
      auto& iterated = *reinterpret_cast<scalar_t*>(iterated_data_bytes);

      TORCH_CHECK_INDEX(idx >= -numel && idx < numel,
                        "out of range: tried to access index ",
                        idx, " on a tensor of ", numel, " elements.");
      if (idx < 0) {
        idx += numel;
      }
      if (!is_contiguous) {
        idx = offset_indexed.get(idx);
      }
      f(iterated, indexed_data, idx);
      iterated_data_bytes += strides[0];
      index_data_bytes += strides[1];
    }
  };

  if (serial_execution) {
    iter.serial_for_each(loop, {0, iter.numel()});
  } else {
    iter.for_each(loop, parallel_grain_size);
  }
}