Home / Class / distribution_nullary_kernel Class — PyTorch Architecture

distribution_nullary_kernel Class — pytorch Architecture

Architecture documentation for the distribution_nullary_kernel function template, defined in aten/src/ATen/native/cuda/DistributionTemplates.h in the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/DistributionTemplates.h lines 106–155

template<typename scalar_t,
         typename accscalar_t,
         typename dist_func_return_t,
         typename RNG,
         typename dist_t,
         typename transform_t>
void distribution_nullary_kernel(at::TensorIteratorBase& iter,
                                 RNG gen,
                                 const dist_t& dist_func,
                                 const transform_t transform_func) {
  const int unroll_factor = sizeof(dist_func_return_t) / sizeof(accscalar_t);
  TORCH_CHECK(unroll_factor >= 1, "unroll_factor must be >= 1.");
  int64_t numel = iter.numel();
  if (numel == 0) {
    return;
  }

  auto [counter_offset, grid, block] = calc_execution_policy(numel, unroll_factor);
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      distribution_nullary_kernel<scalar_t, accscalar_t, dist_func_return_t>(sub_iter,
        gen, dist_func, transform_func);
    }
    return;
  }

  char* out_data = (char*)iter.data_ptr(0);

  auto stream = at::cuda::getCurrentCUDAStream();
  if (iter.is_trivial_1d()) {
    auto strides = iter.get_inner_strides();
    int stride0 = strides[0];
    distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
      numel,
      rng_engine_inputs,
      dist_func,
      [=]__device__(int idx, accscalar_t rand) {
        scalar_t* out = (scalar_t*)&out_data[stride0 * idx];
        *out = transform_func(rand);
      }
    );
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  } else {

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free