
cpu_take_put_kernel Class — pytorch Architecture

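Architecture documentation for cpu_take_put_kernel, a function template defined in IndexKernel.cpp in the PyTorch codebase. It is the shared CPU loop used by the take and put_ kernels. For every element visited by the TensorIterator, it reads a 64-bit linear index, checks that the index lies in [-numel, numel) of the indexed tensor, and wraps negative values by adding numel. When the indexed tensor is not contiguous, it converts the index to a storage offset with IndexToOffset. It then passes the iterated element, the indexed data pointer, and the offset to the callback f. With serial_execution set, the loop runs on one thread via serial_for_each. Otherwise, for_each parallelizes it with a grain size of 3000.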


Source Code

aten/src/ATen/native/cpu/IndexKernel.cpp lines 63–113

```cpp
template <typename scalar_t, typename func_t>
void cpu_take_put_kernel(
    TensorIterator& iter,
    const TensorBase& indexed,
    bool is_indexed_data_mutated,
    const func_t& f,
    bool serial_execution=false) {
  // This kernel follows the same strategy as `cpu_index_kernel`
  // Even though the indexed_tensor is const, we modify it through the data_ptr
  // This is a bit dirty, but otherwise it would be necessary to unnecessarily add tensor
  // with zero strides to `iter` which would not be much better

  // When launch the parallel version, set a relative small grain size less than the INTERNAL::GRAIN_SIZE
  // to make the whole available thread numbers get more balanced work load and a better cache location.
  // The grain size here is chosen by the op benchmark to overcome the thread launch overhead
  // Perhaps tweak this number for `put_`? This number was tweaked for `index_put`
  constexpr int parallel_grain_size = 3000;
  const bool is_contiguous = indexed.is_contiguous();
  const auto numel = indexed.numel();
  const auto offset_indexed = IndexToOffset(indexed);

  auto* indexed_data = is_indexed_data_mutated ?
   indexed.data_ptr<scalar_t>()
   : const_cast<scalar_t*>(indexed.const_data_ptr<scalar_t>());
  auto loop = [&](char** data, const int64_t* strides, int64_t n) {
    auto* iterated_data_bytes = data[0];
    auto* index_data_bytes = data[1];
    for ([[maybe_unused]] const auto elem : c10::irange(n)) {
      auto idx = *reinterpret_cast<int64_t*>(index_data_bytes);
      auto& iterated = *reinterpret_cast<scalar_t*>(iterated_data_bytes);

      TORCH_CHECK_INDEX(idx >= -numel && idx < numel,
                        "out of range: tried to access index ",
                        idx, " on a tensor of ", numel, " elements.");
      if (idx < 0) {
        idx += numel;
      }
      if (!is_contiguous) {
        idx = offset_indexed.get(idx);
      }
      f(iterated, indexed_data, idx);
      iterated_data_bytes += strides[0];
      index_data_bytes += strides[1];
    }
  };
  if (serial_execution) {
    iter.serial_for_each(loop, {0, iter.numel()});
  } else {
    iter.for_each(loop, parallel_grain_size);
  }
}
```
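The per-element logic of the loop above can be illustrated with a minimal single-threaded C++ sketch that does not depend on ATen. It assumes a tensor is modelled as a flat std::vector<double> plus per-dimension sizes and strides, both measured in elements. StridedView, linear_to_offset, normalize_index, take_put_serial, take and put_ are hypothetical names introduced only for this illustration. linear_to_offset approximates what IndexToOffset::get is used for here and is not a copy of it. The two lambdas follow the f(iterated, indexed_data, idx) callback contract of the kernel and are not PyTorch's actual callers. TensorIterator, dtype dispatch, accumulation, and the parallel for_each path are all omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for a strided tensor: a flat storage buffer plus
// per-dimension sizes and strides, both measured in elements.
struct StridedView {
  std::vector<double>* storage;
  std::vector<std::int64_t> sizes;
  std::vector<std::int64_t> strides;

  std::int64_t numel() const {
    std::int64_t n = 1;
    for (std::int64_t s : sizes) n *= s;
    return n;
  }
};

// Hypothetical helper: maps a row-major linear index to a storage offset by
// peeling off the innermost dimension first. For a contiguous view it returns
// the index unchanged; the kernel skips the mapping in that case as a fast path.
std::int64_t linear_to_offset(const StridedView& v, std::int64_t linear) {
  assert(v.sizes.size() == v.strides.size());
  std::int64_t offset = 0;
  for (std::int64_t d = static_cast<std::int64_t>(v.sizes.size()) - 1; d >= 0; --d) {
    offset += (linear % v.sizes[d]) * v.strides[d];
    linear /= v.sizes[d];
  }
  return offset;
}

// Same rule as the TORCH_CHECK_INDEX guard: accept [-numel, numel), then wrap
// negative indices by adding numel.
std::int64_t normalize_index(std::int64_t idx, std::int64_t numel) {
  if (idx < -numel || idx >= numel) {
    throw std::out_of_range("out of range: tried to access index " +
                            std::to_string(idx) + " on a tensor of " +
                            std::to_string(numel) + " elements.");
  }
  return idx < 0 ? idx + numel : idx;
}

// Serial analogue of the kernel's inner loop: for each position i, resolve
// index[i] to a storage offset and call f(iterated[i], data, offset).
template <typename F>
void take_put_serial(std::vector<double>& iterated,
                     const std::vector<std::int64_t>& index,
                     const StridedView& indexed, F f) {
  assert(iterated.size() == index.size());
  const std::int64_t numel = indexed.numel();
  for (std::size_t i = 0; i < index.size(); ++i) {
    std::int64_t idx = normalize_index(index[i], numel);
    f(iterated[i], indexed.storage->data(), linear_to_offset(indexed, idx));
  }
}

// take-style callback: copy indexed[offset] into the iterated element.
std::vector<double> take(const StridedView& src, const std::vector<std::int64_t>& index) {
  std::vector<double> out(index.size());
  take_put_serial(out, index, src,
                  [](double& it, double* data, std::int64_t off) { it = data[off]; });
  return out;
}

// put_-style callback without accumulation: write the iterated element into
// indexed[offset].
void put_(const StridedView& dst, const std::vector<std::int64_t>& index,
          std::vector<double> values) {
  take_put_serial(values, index, dst,
                  [](double& it, double* data, std::int64_t off) { data[off] = it; });
}
```

As an example, take storage {0, 1, 2, 3, 4, 5} viewed as the 3×2 transpose of a row-major 2×3 tensor (sizes {3, 2}, strides {1, 3}). Here, take(view, {0, 1, -1, 4}) returns {0, 3, 5, 2}. The -1 is wrapped to 5 before the stride mapping, just as idx += numel does in the kernel. Checking bounds before wrapping is what makes -numel a valid index and numel an invalid one.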
