Home / Class / reduce Class — pytorch Architecture

reduce Class — pytorch Architecture

Architecture documentation for the reduce class in ScatterGatherKernel.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/ScatterGatherKernel.cpp lines 692–815

// Scatter-reduce along dim 0 for the case where one index value selects a
// whole row: row `i` of `src` (K consecutive elements) is reduced into row
// `index_data[i]` of `self`.
//
// Outline:
//   1. Validate every index and build (destination row, source row) pairs.
//   2. Radix-sort the pairs by destination row so that all contributions to
//      one destination row are adjacent.
//   3. Compress the sorted keys into a list of distinct destination rows plus
//      the offset at which each row's run of contributions starts.
//   4. In parallel over distinct destination rows, reduce all contributing
//      `src` rows into the matching `self` row. Each destination row is handled
//      by exactly one loop iteration.
//
// Template parameters:
//   scalar_t - element type of `self` and `src`.
//   reduce   - reduction kind, forwarded to the `_init` / `update` / `write` helpers.
// Parameters:
//   self         - destination; updated in place through its data pointer.
//   index        - int64 indices; the first `nnz` values of its data are read.
//   src          - source values, addressed as rows of K elements.
//   include_self - forwarded to `_init`, and counted as one extra contribution
//                  in the per-row `count` handed to `write`.
//
// Raises (via TORCH_CHECK) if any index is outside [0, M).
//
// NOTE(review): addressing uses `row * K` / `col * K`, so `self` and `src` are
// assumed to be contiguous with rows of length K, and `index` is presumably
// expanded (broadcast) across the trailing dimensions -- confirm against the caller.
template <typename scalar_t, ReductionType reduce>
void cpu_scatter_reduce_expanded_index(const Tensor& self, const Tensor& index, const Tensor& src, bool include_self) {
  const int64_t* index_data = index.const_data_ptr<int64_t>();
  scalar_t* self_data = self.data_ptr<scalar_t>();
  const scalar_t* src_data = src.const_data_ptr<scalar_t>();

  // M: number of destination rows, nnz: number of index entries (source rows).
  const int64_t M = ensure_nonempty_size(self, 0);
  const int64_t nnz = ensure_nonempty_size(index, 0);
  // Nothing to scatter. Return before `index.numel() / nnz` divides by zero
  // and before `sorted_col_index_keys[0]` is read from a zero-length buffer.
  if (nnz == 0) {
    return;
  }
  // K: number of elements per row.
  const int64_t K = index.numel() / nnz;

  const int64_t index_upper_bound = M;

  // keys[i] is the destination row for source row values[i]; the *_tmp arrays
  // are scratch space handed to the radix sort.
  auto keys = std::make_unique<int64_t[]>(nnz);
  auto values = std::make_unique<int64_t[]>(nnz);
  auto keys_tmp = std::make_unique<int64_t[]>(nnz);
  auto values_tmp = std::make_unique<int64_t[]>(nnz);
  at::parallel_for(0, nnz, 1, [&](int64_t begin, int64_t end) {
    for (const auto i : c10::irange(begin, end)) {
      int64_t index = index_data[i];
      TORCH_CHECK(index >= 0 && index < index_upper_bound,
                  "index ", index,
                  " is out of bounds for dimension ", 0,
                  " with size ", index_upper_bound);
      keys[i] = index;
      values[i] = i;
    }
  });

  // Sort the pairs by destination row. Only the returned pointers are used
  // from here on; presumably they may refer to either the primary or the tmp
  // buffers, and `M` is the exclusive upper bound on key values -- confirm
  // against the fbgemm documentation.
  int64_t* sorted_col_index_keys = nullptr;
  int64_t* sorted_col_index_values = nullptr;
  std::tie(sorted_col_index_keys, sorted_col_index_values) = fbgemm::radix_sort_parallel(
      keys.get(),
      values.get(),
      keys_tmp.get(),
      values_tmp.get(),
      nnz,
      M);

  // Counting pass: each thread counts the positions in its range where the
  // sorted key differs from its predecessor, i.e. where a new destination row
  // starts. The loop starts at 1 because position 0 has no predecessor.
  int num_threads = at::get_num_threads();
  std::vector<int64_t> num_uniq(num_threads, 0);
  at::parallel_for(1, nnz, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    for(const auto i : c10::irange(begin, end)) {
      if (sorted_col_index_keys[i] != sorted_col_index_keys[i - 1]) {
        num_uniq[tid]++;
      }
    }
  });
  // Position 0 always starts a row but was not visited above; account for it,
  // then turn the per-thread counts into an inclusive prefix sum.
  num_uniq[0]++;
  for (const auto n : c10::irange(1, num_threads)) {
    num_uniq[n] += num_uniq[n - 1];
  }

  // in case some rows are not written into, num_nonzero_rows will be smaller than M
  int64_t num_nonzero_rows = num_uniq[num_threads - 1];
  // row_index[m]        : m-th distinct destination row.
  // row_index_offset[m] : start of that row's run in the sorted arrays; the run
  //                       ends at row_index_offset[m + 1], hence the extra slot.
  auto row_index_tmp = std::make_unique<int64_t[]>(num_nonzero_rows);
  auto row_index_offset_tmp = std::make_unique<int64_t[]>(num_nonzero_rows + 1);
  int64_t* row_index = row_index_tmp.get();
  int64_t* row_index_offset = row_index_offset_tmp.get();
  row_index[0] = sorted_col_index_keys[0];
  row_index_offset[0] = 0;
  row_index_offset[num_nonzero_rows] = nnz;

  // Fill pass: each thread writes the row starts it finds, beginning at the
  // slot given by the prefix sum of the threads before it (slot 0 was filled
  // above, so thread 0 starts at slot 1).
  // NOTE(review): this relies on each thread id receiving the same
  // [begin, end) range as in the counting pass, and on ranges being handed to
  // thread ids in increasing order; otherwise the write cursors derived from
  // `num_uniq` would not line up. Confirm this holds for at::parallel_for.
  at::parallel_for(1, nnz, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    int64_t* t_index = row_index + ((tid == 0) ? 1 : num_uniq[tid - 1]);
    int64_t* t_index_offset = row_index_offset + ((tid == 0) ? 1 : num_uniq[tid - 1]);
    for (const auto i : c10::irange(begin, end)) {
      if (sorted_col_index_keys[i] != sorted_col_index_keys[i - 1]) {
        *t_index = sorted_col_index_keys[i];
        *t_index_offset = i;
        t_index++;
        t_index_offset++;
      }
    }
  });

  // For reduced floating point element types, accumulate in a separate
  // per-thread buffer of `opmath_t` (one row of K elements per thread) and
  // convert back to `scalar_t` afterwards. For all other types the reduction
  // is done directly in `self`.
  using opmath_t = at::opmath_type<scalar_t>;
  Tensor buffer;
  opmath_t* buffer_data = nullptr;
  static constexpr bool need_acc = is_reduced_floating_point_v<scalar_t>;
  if constexpr (need_acc) {
    auto acc_type = at::toAccumulateType(self.scalar_type(), /*is_cuda=*/true);
    buffer = at::zeros({num_threads, K}, self.options().dtype(acc_type));
    buffer_data = buffer.data_ptr<opmath_t>();
  }

  // TODO: do blocking on col dimension to reduce WR bandwidth
  at::parallel_for(0, num_nonzero_rows, 1, [&](int64_t begin, int64_t end) {
    int tid = at::get_thread_num();
    // The accumulation buffer has exactly `num_threads` rows; a larger thread
    // id would index past its end.
    TORCH_CHECK(tid < num_threads,
                "expect thread id smaller than ", num_threads, ", got thread id ", tid);
    opmath_t* buffer_ptr = nullptr;

    for (const auto m : c10::irange(begin, end)) {
      int64_t row = row_index[m];
      int64_t off_start = row_index_offset[m];
      int64_t off_end = row_index_offset[m + 1];
      scalar_t* self_ptr = self_data + row * K;
      if constexpr (need_acc) {
        buffer_ptr = buffer_data + tid * K;
      } else {
        // No separate accumulator: reduce in place in `self`.
        // presumably opmath_t == scalar_t on this branch -- TODO confirm
        buffer_ptr = reinterpret_cast<opmath_t*>(self_ptr);
      }

      // step 1: reinit rows in `self` if needed
      _init<scalar_t, reduce>(self_ptr, buffer_ptr, K, include_self);

      // step 2: reduce
      for (const auto n : c10::irange(off_start, off_end)) {
        // `col` is the original position of this entry in `index`, i.e. the
        // source row in `src`.
        int64_t col = sorted_col_index_values[n];
        update<scalar_t, reduce>(buffer_ptr, src_data + col * K, K);
      }
      if constexpr (need_acc) {
        // Copy the accumulated row back into `self`, converting to scalar_t.
        vec::convert(buffer_ptr, self_ptr, K);
      }

      // step 3: finalize
      // `count` is the number of values reduced into this row: one per
      // contributing source row, plus one for the original value when
      // `include_self` is set.
      int64_t count = include_self ? 1 : 0;
      count += off_end - off_start;
      write<scalar_t, reduce>(self_ptr, count, K);
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free