Home / Class / self Class — PyTorch Architecture

self Class — pytorch Architecture

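Architecture documentation for the entity indexed as the `self` class in ScatterGatherKernel.cpp from the PyTorch codebase. (Note: `self` is not actually a class; it is the tensor parameter of the scalar-value `operator()` overload shown below.)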

Entity Profile

Source Code

aten/src/ATen/native/cpu/ScatterGatherKernel.cpp lines 168–263

  /// Applies `kernel_func` between elements of `self` and a single scalar
  /// `value`, with the target positions along `dim` driven by `_index`.
  ///
  /// The offset along `dim` that is handed to `kernel_func` depends on the
  /// path taken. When `dim` is not the last dimension of the working buffer,
  /// the loop in this function uses the index value if `is_scatter_like` is
  /// true, and otherwise uses the position `i` along `dim`. When `dim` is the
  /// last dimension, the work is delegated to `_cpu_scatter_gather_dim_loop`.
  ///
  /// @param self        Tensor updated in place. When its dtype is a reduced
  ///                    floating type (`isReducedFloatingType`), the work is
  ///                    done in an accumulation `buffer` created by
  ///                    `create_acc_buffer`, and the result is copied back into
  ///                    `self` at the end.
  /// @param dim         Dimension along which indices are applied. Assumed to
  ///                    be already wrapped to a non-negative value — TODO
  ///                    confirm at the call site.
  /// @param _index      Index tensor; converted to int64 before use.
  /// @param value       Scalar operand, converted to `scalar_t` for each call
  ///                    to `kernel_func`.
  /// @param method_name Not used by this overload.
  /// @param kernel_func Element functor, invoked in this function as
  ///                    `kernel_func(opmath_t* self_elem, scalar_t* value)`.
  /// @throws c10::Error (via TORCH_CHECK) if an index value is outside
  ///         [0, buffer.size(dim)) on the non-last-dim path. The last-dim path
  ///         presumably performs the same check inside
  ///         `_cpu_scatter_gather_dim_loop` — TODO confirm.
  template <typename func_t>
  void operator()(const Tensor& self, int64_t dim,
    const Tensor& _index, const Scalar& value,
    const std::string& method_name, func_t& kernel_func) {

    // An empty index means there is nothing to scatter or gather. Returning
    // here also guards the grain-size computation below, which divides by
    // index.size(dim); that divisor is 0 exactly when the index is empty
    // along `dim`.
    if (_index.numel() == 0) {
      return;
    }

    Tensor buffer;
    // The loops below read indices as int64_t, so normalize the dtype first.
    Tensor index = _index.to(ScalarType::Long);
    bool need_acc = isReducedFloatingType(self.scalar_type());
    create_acc_buffer(buffer, self, need_acc);

    // `dim` is traversed manually inside the kernel, so it is squashed out of
    // the TensorIterator iteration space (`squash_dims`). The iterator then
    // sees size 1 along `dim`, which makes its DimCounter take the form
    // (i_1, ..., i_{dim-1}, 0, i_{dim+1}, ..., i_n).
    auto iter = TensorIteratorConfig()
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .declare_static_shape(index.sizes(), /*squash_dims=*/dim)
      .add_output(buffer)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(buffer, dim);
    auto self_dim_size = ensure_nonempty_size(buffer, dim);

    auto index_dim_stride = ensure_nonempty_stride(index, dim);
    auto index_dim_size = ensure_nonempty_size(index, dim);

    auto index_upper_bound = self_dim_size;

    // Since the index dimension is squashed, each iterator element covers
    // index_dim_size kernel applications. Scale the grain size down to keep
    // the same parallel granularity. index_dim_size > 0 is guaranteed by the
    // early return above.
    int64_t grain_size = std::max<int64_t>(1, at::internal::GRAIN_SIZE / index_dim_size);

    // Loop-invariant: decides which loop order is used inside the kernel.
    const bool dim_is_last = (dim == buffer.dim() - 1);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      ScalarType::Bool, ScalarType::Half, ScalarType::BFloat16, self.scalar_type(),
      "scatter_gather_scalar_cpu", [&] {
        constexpr auto SELF_ITER_STRIDE_IDX = 0;
        constexpr auto INDEX_ITER_STRIDE_IDX = 1;
        using opmath_t = at::opmath_type<scalar_t>;
        _cpu_scatter_gather_dim_loop<is_scatter_like> loop_func;
        auto loop = [&](char** data, const int64_t* strides, int64_t n) {
          auto* self_data_bytes = data[SELF_ITER_STRIDE_IDX];
          auto* index_data_bytes = data[INDEX_ITER_STRIDE_IDX];
          // We switch between a TensorIterator-outer/dim-inner loop order and
          // a dim-outer/TensorIterator-inner order depending on whether dim is
          // the last dimension.
          if (dim_is_last) {
            for ([[maybe_unused]] const auto nelem : c10::irange(n)) {
              // The dim loop is a separate code block for better performance.
              loop_func.template operator()<scalar_t, func_t>(
                (opmath_t*)self_data_bytes, self_dim_stride,
                (int64_t*)index_data_bytes, index_dim_stride,
                value, dim, index_dim_size, index_upper_bound,
                kernel_func);

              self_data_bytes += strides[SELF_ITER_STRIDE_IDX];
              index_data_bytes += strides[INDEX_ITER_STRIDE_IDX];
            }
          }
          else {
            for (const auto i : c10::irange(index_dim_size)) {
              auto* self_data = self_data_bytes;
              auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride);
              for ([[maybe_unused]] const auto nelem : c10::irange(n)) {
                int64_t idx_dim = *(int64_t*)index_data;
                // We are not putting idx_dim in the error message because it
                // disables loop optimization in clang-7.
                TORCH_CHECK(idx_dim >= 0 && idx_dim < index_upper_bound,
                            "index ", *(int64_t*)index_data,
                            " is out of bounds for dimension ", dim,
                            " with size ", index_upper_bound);

                auto temp = value.to<scalar_t>();
                // Scatter-like ops address `self` by the index value;
                // gather-like ops address it by the position along `dim`.
                kernel_func((opmath_t*)self_data + (is_scatter_like ? idx_dim : i) * self_dim_stride, &temp);

                self_data += strides[SELF_ITER_STRIDE_IDX];
                index_data += strides[INDEX_ITER_STRIDE_IDX];
              }
            }
          }
        };
        iter.for_each(loop, grain_size);
      }
    );
    // The accumulation buffer was used instead of `self`; write results back.
    if (need_acc) {
      self.copy_(buffer);
    }
  }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free