
bernoulli_kernel Function Template (PyTorch Architecture)

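Architecture documentation for bernoulli_kernel in DistributionTemplates.h from the PyTorch codebase. bernoulli_kernel is a function template, not a class. It fills the CUDA tensor self with Bernoulli samples: each element becomes 1 with the probability held by the matching element of p_, and 0 otherwise. Before dispatching to bernoulli_tensor_cuda_kernel, it does four things:

- It reserves Philox random-number state from the generator while holding the generator's mutex.
- It checks that p_ has a floating-point dtype.
- It converts p_ to double when self is double, and to float otherwise, placing the result on self's device.
- It expands the probabilities to self's shape.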
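The sampling rule itself can be illustrated with a minimal CPU reference that uses only the C++ standard library. For each element, it draws a uniform number in [0, 1) and outputs 1 when that draw falls below the probability.

This is a sketch under stated assumptions. The function name bernoulli_reference is hypothetical. The sketch swaps PyTorch's Philox generator for std::mt19937_64, so it illustrates the semantics without reproducing PyTorch's samples. It also skips the broadcasting and dtype handling that bernoulli_kernel performs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical CPU reference (not PyTorch code): out[i] is 1 with probability
// p[i] and 0 otherwise. Uses std::mt19937_64 instead of Philox, so the
// samples will not match PyTorch's CUDA output.
template <typename Out, typename Prob>
std::vector<Out> bernoulli_reference(const std::vector<Prob>& p, std::uint64_t seed) {
  std::mt19937_64 engine(seed);
  std::uniform_real_distribution<Prob> uniform(Prob(0), Prob(1));  // draws in [0, 1)
  std::vector<Out> out(p.size());
  for (std::size_t i = 0; i < p.size(); ++i) {
    assert(p[i] >= Prob(0) && p[i] <= Prob(1));  // probabilities must lie in [0, 1]
    // A uniform draw on [0, 1) is below p[i] with probability p[i].
    out[i] = static_cast<Out>(uniform(engine) < p[i]);
  }
  return out;
}
```

For example, bernoulli_reference<float>(std::vector<double>{0.0, 1.0}, 42) returns {0.0f, 1.0f}. A draw from [0, 1) is never below 0 and is always below 1, so these two probabilities are deterministic.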


Source Code

aten/src/ATen/native/cuda/DistributionTemplates.h lines 651–672

```cpp
template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) {
  PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(10);
  }
  TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type());
  // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else
  const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
  auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type));
  auto p = expand_inplace(self, p_cuda);
  AT_DISPATCH_ALL_TYPES_AND3(
    at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
      if (std::is_same_v<scalar_t, double>) {
        return bernoulli_tensor_cuda_kernel<double, double>(self, *p, rng_engine_inputs);
      } else {
        return bernoulli_tensor_cuda_kernel<scalar_t, float>(self, *p, rng_engine_inputs);
      }
    });
}
```
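Two host-side patterns in the listing can be sketched without CUDA:

- Hold the generator's mutex only long enough to snapshot the Philox state and advance its offset.
- Choose the probability dtype from self's dtype.

This is a simplified sketch. ToyPhiloxGenerator, ToyPhiloxState, reserve_offsets, ToyDtype and probability_dtype are hypothetical names. The toy just adds the increment to an offset, whereas PyTorch's real philox_cuda_state bookkeeping is more involved.

```cpp
#include <cassert>
#include <cstdint>
#include <mutex>

// Hypothetical stand-in for a Philox generator's host-side state: a seed
// plus an offset that each launch advances, so later launches read a fresh
// part of the random stream.
struct ToyPhiloxGenerator {
  std::mutex mutex_;
  std::uint64_t seed = 0;
  std::uint64_t offset = 0;
};

// Snapshot that would be passed by value to a kernel launch.
struct ToyPhiloxState {
  std::uint64_t seed;
  std::uint64_t offset;
};

// Scoped-lock pattern: the mutex guards only the read-and-advance of the
// offset and is released when this function returns.
ToyPhiloxState reserve_offsets(ToyPhiloxGenerator& gen, std::uint64_t increment) {
  assert(increment > 0);  // a zero reservation would let the next caller reuse the same offsets
  std::lock_guard<std::mutex> lock(gen.mutex_);
  ToyPhiloxState state{gen.seed, gen.offset};
  gen.offset += increment;
  return state;
}

enum class ToyDtype { Bool, Half, BFloat16, Float, Double, Int64 };

// Same rule as p_type in bernoulli_kernel: double probabilities only for a
// double output, float probabilities for every other output dtype.
ToyDtype probability_dtype(ToyDtype self_dtype) {
  return self_dtype == ToyDtype::Double ? ToyDtype::Double : ToyDtype::Float;
}
```

In this toy, two successive reserve_offsets(gen, 10) calls, using the same increment as the listing, receive offsets 0 and 10. Their launches therefore read disjoint parts of the stream. The lock covers only this cheap bookkeeping, never the sampling work.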
