
parallel_sparse_csr: PyTorch Architecture

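Architecture documentation for parallel_sparse_csr, a function template (not a class) defined in aten/src/ATen/native/cpu/utils.h in the PyTorch codebase. Given the row-pointer array (crow) of a CSR sparse matrix with M rows and nnz nonzeros, it splits the rows into one contiguous range per thread so that each range holds roughly nnz / num_threads nonzeros, then calls the callback f(begin, end) on each range in parallel.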


Source Code

aten/src/ATen/native/cpu/utils.h lines 178–216

template <typename index_t, typename F>
inline void parallel_sparse_csr(
    const TensorAccessor<index_t, 1>& crow_acc,
    const int64_t M,
    const int64_t nnz,
    const F& f) {
  TORCH_CHECK(crow_acc.size(0) == M + 1);

  // directly parallel on `M` may lead to load imbalance,
  // statically determine thread partition here to average payload
  // for each thread.
  int num_threads = at::get_num_threads();
  std::vector<int64_t> thread_splits(num_threads + 1, M);

  int64_t thread_averge_payload = std::max((int64_t)1, divup(nnz, num_threads));

  thread_splits[0] = 0;
  int64_t sum = 0;
  int64_t t = 1;
  for (const auto m : c10::irange(M)) {
    int64_t row_start = crow_acc[m];
    int64_t row_end = crow_acc[m + 1];
    sum += row_end - row_start;
    if (sum > t * thread_averge_payload) {
      thread_splits[t] = m;
      t++;
    }
  }
  // need to restore the last index,
  // due to rounding error when calculating `thread_averge_payload`.
  thread_splits[num_threads] = M;

  at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
    int tid = at::get_thread_num();
    int64_t begin = thread_splits[tid];
    int64_t end = thread_splits[tid + 1];
    f(begin, end);
  });
}
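To make the partitioning concrete, here is a minimal standalone sketch that mirrors the loop above without ATen. It is an illustration under stated assumptions, not PyTorch code: compute_thread_splits, parallel_rows and spmv_csr are hypothetical names, the crow array is a plain std::vector instead of a TensorAccessor, the ceiling division stands in for divup, and one std::thread per partition replaces the at::parallel_for dispatch. The example payload is a CSR matrix-vector product in which each worker writes only the rows in its own [begin, end) range.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

// Hypothetical standalone version of the partitioning loop in
// parallel_sparse_csr. `crow` plays the role of crow_acc (length M + 1).
std::vector<int64_t> compute_thread_splits(const std::vector<int64_t>& crow,
                                           int num_threads) {
  assert(!crow.empty() && num_threads >= 1);
  const int64_t M = static_cast<int64_t>(crow.size()) - 1;
  const int64_t nnz = crow[M] - crow[0];
  // Boundaries that are never set default to M, as in the original.
  std::vector<int64_t> splits(static_cast<std::size_t>(num_threads) + 1, M);
  // Target nonzeros per thread: ceil(nnz / num_threads), at least 1.
  const int64_t avg =
      std::max<int64_t>(1, (nnz + num_threads - 1) / num_threads);
  splits[0] = 0;
  int64_t sum = 0;
  int64_t t = 1;
  for (int64_t m = 0; m < M; ++m) {
    sum += crow[m + 1] - crow[m];  // running nonzero count through row m
    if (sum > t * avg) {
      splits[t] = m;  // row m becomes the first row of partition t
      ++t;
    }
  }
  splits[num_threads] = M;
  return splits;
}

// Stand-in for the at::parallel_for dispatch (an assumption of this sketch):
// one std::thread per partition, each calling f(begin, end).
void parallel_rows(const std::vector<int64_t>& splits,
                   const std::function<void(int64_t, int64_t)>& f) {
  std::vector<std::thread> workers;
  for (std::size_t tid = 0; tid + 1 < splits.size(); ++tid) {
    workers.emplace_back(f, splits[tid], splits[tid + 1]);
  }
  for (auto& w : workers) w.join();
}

// Example payload: y = A * x for a CSR matrix A. Partitions are disjoint,
// so each thread writes distinct elements of y and needs no locking.
std::vector<double> spmv_csr(const std::vector<int64_t>& crow,
                             const std::vector<int64_t>& col,
                             const std::vector<double>& val,
                             const std::vector<double>& x, int num_threads) {
  const int64_t M = static_cast<int64_t>(crow.size()) - 1;
  std::vector<double> y(static_cast<std::size_t>(M), 0.0);
  const std::vector<int64_t> splits = compute_thread_splits(crow, num_threads);
  parallel_rows(splits, [&](int64_t begin, int64_t end) {
    for (int64_t r = begin; r < end; ++r) {
      double acc = 0.0;
      for (int64_t k = crow[r]; k < crow[r + 1]; ++k) {
        acc += val[k] * x[col[k]];
      }
      y[r] = acc;
    }
  });
  return y;
}
```

For crow = {0, 6, 7, 8, 9, 10, 11, 12, 14} (row lengths 6, 1, 1, 1, 1, 1, 1, 2, so nnz = 14) and two threads, the target payload is 7 and the splits come out as {0, 2, 8}: rows [0, 2) and [2, 8) each hold 7 nonzeros, whereas an even split by row count ([0, 4) and [4, 8)) would give 9 and 5. Note that the row which pushes the running sum past a threshold starts the next partition, so a single very heavy row can still dominate one thread's share.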
