
batch_iterator_with_broadcasting Function Template — PyTorch Architecture

Architecture documentation for the batch_iterator_with_broadcasting function template in LinearAlgebraUtils.h from the PyTorch codebase. The helper co-iterates the broadcast batch dimensions of two batched-matrix tensors, `a` and `b`, and invokes a user-supplied callback on each matching pair of matrices.

Entity Profile

Source Code

aten/src/ATen/native/LinearAlgebraUtils.h lines 196–262

template<typename scalar_t, typename func_t>
void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
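  // The batch dimensions are everything except the last two
  // (matrix) dimensions of `a` and `b`.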
  IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
  IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);

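  // Enumerate the batches of each tensor with linear indices 0..batchCount-1,
  // laid out in the batch shape, so the two index tensors broadcast against
  // each other exactly as the batches of `a` and `b` do.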
  auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
  auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);

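  // Co-iterate the two index tensors. With `b`'s indices as a non-resizable
  // output, `a`'s indices are broadcast to `b`'s batch shape, so every step
  // of the iterator yields a matching pair of linear batch indices.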
  TensorIterator iter = TensorIteratorConfig()
    .set_check_mem_overlap(false)
    .check_all_same_dtype(false)
    .resize_outputs(false)
    .add_output(b_linear_batch_idx)
    .add_input(a_linear_batch_idx)
    .build();

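  // Flatten the batch dimensions so that a single linear index selects
  // one matrix from each tensor.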
  auto m = a.size(-2);
  auto n = a.size(-1);
  auto a_3d = a.view({batchCount(a), m, n});
  auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});

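  // When `a` broadcasts over `b`, the same batch of `a` can be visited more
  // than once, and `f` may modify it in place. Keep a pristine copy of `a`;
  // on every visit after the first, restore the batch from that copy before
  // handing it to `f`.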
  auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
  Tensor a_buffer, a_was_accessed, a_buffer_3d;
  std::function<void(int64_t)> check_if_copy_needed_for_a
    = [](int64_t /*a_curr_linear_batch_idx*/){};
  if (a_broadcasts_over_b) {
    a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
      .copy_(a);
    a_was_accessed = at::zeros(batchCount(a), at::kBool);
    a_buffer_3d = a_buffer.view({batchCount(a), m, n});
    check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
      auto* a_was_accessed_flag = a_was_accessed
        .select(0, a_curr_linear_batch_idx)
        .data_ptr<bool>();
      if (!(*a_was_accessed_flag)) {
        *a_was_accessed_flag = true;
      }
      else {
        a_3d.select(0, a_curr_linear_batch_idx)
          .copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
      }
    };
  }

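  // For each pair of linear batch indices produced by the iterator, pass `f`
  // raw pointers into the current batches of `a` and `b`.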
  auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
    auto* b_batch_idx_ptr = data[0];
    auto* a_batch_idx_ptr = data[1];

    for ([[maybe_unused]] const auto elem : c10::irange(nelems)) {
      auto b_curr_linear_batch_idx =
          *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
      auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);

      check_if_copy_needed_for_a(a_curr_linear_batch_idx);

      auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
        .data_ptr<scalar_t>();
      f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);

      b_batch_idx_ptr += strides[0];
      a_batch_idx_ptr += strides[1];
    }
  };
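  // Run the loop serially over all batchCount(b) batches; the copy-back
  // bookkeeping above is not synchronized for parallel execution.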
  iter.serial_for_each(loop, {0, batchCount(b)});
}
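
Usage Sketch

The function calls f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx) once per broadcast pair of batches, transparently restoring any batch of `a` that is reused across pairs. The following is a minimal, hypothetical caller, not taken from the PyTorch sources: it accumulates each broadcast batch of `a` into the matching batch of `b`, assuming both tensors are contiguous, hold doubles, and have at least two dimensions. Note that LinearAlgebraUtils.h is an internal header, so a caller like this would live inside the PyTorch source tree.

#include <ATen/ATen.h>
#include <ATen/native/LinearAlgebraUtils.h>

// Hypothetical example; accumulate_batches is not part of the codebase.
void accumulate_batches(const at::Tensor& a, const at::Tensor& b) {
  const auto m = a.size(-2);
  const auto n = a.size(-1);
  at::native::batch_iterator_with_broadcasting<double>(
      a, b,
      [&](double* a_working_ptr, double* b_working_ptr,
          int64_t /*a_linear_batch_idx*/) {
        // Each call sees one m-by-n batch of `a` and the matching batch
        // of `b`; both are dense within the flattened 3-D views.
        for (int64_t i = 0; i < m * n; ++i) {
          b_working_ptr[i] += a_working_ptr[i];
        }
      });
}

With `a` of batch shape {1, 2} and `b` of batch shape {3, 2}, for example, every batch of `a` is paired with three batches of `b`; the buffer-and-restore path above ensures each visit after the first still sees `a`'s original values, even if `f` modified them in place.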
