distribution_binary_kernel Function Template — PyTorch Architecture
Architecture documentation for the distribution_binary_kernel function template in DistributionTemplates.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/DistributionTemplates.h lines 228–267
// Host-side launcher for a two-input random-distribution elementwise kernel.
//
// `f` is a functor whose first parameter must be `curandStatePhilox4_32_10_t&`;
// its second and third parameter types give the element types of the two
// inputs, and its result type gives the element type of the output.
//
// Operand 0 of `iter` is used as the output and operands 1 and 2 as the
// inputs. An iterator that cannot be addressed with 32-bit indexing is split
// into sub-iterators that can, and each one is processed recursively. An
// iterator with zero elements returns without launching anything.
template <typename func_t>
void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) {
  using traits = function_traits<func_t>;
  static_assert(std::is_same_v<typename traits::template arg<0>::type, curandStatePhilox4_32_10_t&>, "the first argument of functor must be curandStatePhilox4_32_10_t");
  using lhs_t = typename traits::template arg<1>::type;
  using rhs_t = typename traits::template arg<2>::type;
  using result_t = typename traits::result_type;

  // Too large for 32-bit offsets: handle each 32-bit-indexable piece on its own.
  // NOTE(review): every piece is launched with the same `philox_args`; confirm
  // that the kernel still produces independent random streams across pieces.
  const bool needs_split = !iter.can_use_32bit_indexing();
  if (needs_split) {
    for (auto& piece : iter.with_32bit_indexing()) {
      distribution_binary_kernel(piece, philox_args, f);
    }
    return;
  }

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());

  const int64_t total = iter.numel();
  if (total == 0) {
    return;
  }

  // Operand 0 is the output; operands 1 and 2 are the two inputs.
  auto* out_ptr = static_cast<result_t*>(iter.data_ptr(0));
  const auto* lhs_ptr = static_cast<const lhs_t*>(iter.data_ptr(1));
  const auto* rhs_ptr = static_cast<const rhs_t*>(iter.data_ptr(2));

  // Launch configuration: enough blocks to cover `total` elements when each
  // block handles block_work_size() of them (rounded up).
  const int64_t num_blocks = (total + block_work_size() - 1) / block_work_size();
  const auto threads_per_block = num_threads();
  auto cuda_stream = at::cuda::getCurrentCUDAStream();

  if (iter.is_contiguous()) {
    // Contiguous operands: use the trivial offset calculators.
    distribution_binary_elementwise_kernel<<<num_blocks, threads_per_block, 0, cuda_stream>>>(
        total, f, philox_args, out_ptr, lhs_ptr, rhs_ptr,
        TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return;
  }

  // Non-contiguous operands: build the offset calculators from the iterator.
  distribution_binary_elementwise_kernel<<<num_blocks, threads_per_block, 0, cuda_stream>>>(
      total, f, philox_args, out_ptr, lhs_ptr, rhs_ptr,
      make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free