batch_iterator_with_broadcasting Function Template — pytorch Architecture
Architecture documentation for the batch_iterator_with_broadcasting function template in LinearAlgebraUtils.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/LinearAlgebraUtils.h lines 196–262
template<typename scalar_t, typename func_t>
void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const func_t& f) {
IntArrayRef a_batch_sizes(a.sizes().data(), a.dim() - 2);
IntArrayRef b_batch_sizes(b.sizes().data(), b.dim() - 2);
auto a_linear_batch_idx = at::arange(batchCount(a)).view(a_batch_sizes);
auto b_linear_batch_idx = at::arange(batchCount(b)).view(b_batch_sizes);
TensorIterator iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(b_linear_batch_idx)
.add_input(a_linear_batch_idx)
.build();
auto m = a.size(-2);
auto n = a.size(-1);
auto a_3d = a.view({batchCount(a), m, n});
auto b_3d = b.view({batchCount(b), b.size(-2), b.size(-1)});
auto a_broadcasts_over_b = (a_batch_sizes != b_batch_sizes);
Tensor a_buffer, a_was_accessed, a_buffer_3d;
std::function<void(int64_t)> check_if_copy_needed_for_a
= [](int64_t /*a_curr_linear_batch_idx*/){};
if (a_broadcasts_over_b) {
a_buffer = at::empty_strided(a.sizes(), a.strides(), a.options())
.copy_(a);
a_was_accessed = at::zeros(batchCount(a), at::kBool);
a_buffer_3d = a_buffer.view({batchCount(a), m, n});
check_if_copy_needed_for_a = [&](int64_t a_curr_linear_batch_idx) {
auto* a_was_accessed_flag = a_was_accessed
.select(0, a_curr_linear_batch_idx)
.data_ptr<bool>();
if (!(*a_was_accessed_flag)) {
*a_was_accessed_flag = true;
}
else {
a_3d.select(0, a_curr_linear_batch_idx)
.copy_(a_buffer_3d.select(0, a_curr_linear_batch_idx));
}
};
}
auto loop = [&](char** data, const int64_t* strides, int64_t nelems) {
auto* b_batch_idx_ptr = data[0];
auto* a_batch_idx_ptr = data[1];
for ([[maybe_unused]] const auto elem : c10::irange(nelems)) {
auto b_curr_linear_batch_idx =
*reinterpret_cast<int64_t*>(b_batch_idx_ptr);
auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);
check_if_copy_needed_for_a(a_curr_linear_batch_idx);
auto* a_working_ptr = a_3d.select(0, a_curr_linear_batch_idx)
.data_ptr<scalar_t>();
auto* b_working_ptr = b_3d.select(0, b_curr_linear_batch_idx)
.data_ptr<scalar_t>();
f(a_working_ptr, b_working_ptr, a_curr_linear_batch_idx);
b_batch_idx_ptr += strides[0];
a_batch_idx_ptr += strides[1];
}
};
iter.serial_for_each(loop, {0, batchCount(b)});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free