bernoulli_kernel Class — pytorch Architecture

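Architecture documentation for the bernoulli_kernel function template in aten/src/ATen/native/cpu/DistributionTemplates.h from the PyTorch codebase. Despite the "Class" label, bernoulli_kernel is a free function template parameterized on the random number generator type (RNG).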

Source Code

aten/src/ATen/native/cpu/DistributionTemplates.h lines 369–399

template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
  self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    using self_t = scalar_t;
    auto p_cpu = p_.to(kCPU);
    auto p = expand_inplace(self, p_cpu);
    auto iter = TensorIteratorConfig()
        .add_output(self)
        .add_const_input(*p)
        .check_all_same_dtype(false)
        .build();
    if (p->scalar_type() == kDouble) {
      cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
        at::bernoulli_distribution<double> bernoulli(p_val);
        return static_cast<self_t>(bernoulli(generator));
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
        using p_t = scalar_t;
        cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
          at::bernoulli_distribution<float> bernoulli(p_val);
          return static_cast<self_t>(bernoulli(generator));
        });
      });
    }
  });
}
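
The core pattern in the kernel above (lock the generator once, then draw one Bernoulli sample per output element from the matching probability) can be illustrated with the C++ standard library alone. The following is a minimal sketch, not ATen code: LockedGenerator and bernoulli_fill are hypothetical names introduced for illustration, std::vector stands in for a tensor, a plain loop stands in for cpu_serial_kernel, and the single-element case is only a rough stand-in for the broadcasting done by expand_inplace.

```cpp
#include <cassert>
#include <cstddef>
#include <mutex>
#include <random>
#include <vector>

// Hypothetical stand-in for a generator object that carries its own mutex,
// loosely mirroring the generator->mutex_ member used by the kernel.
struct LockedGenerator {
  std::mutex mutex_;
  std::mt19937 engine;
  explicit LockedGenerator(unsigned seed) : engine(seed) {}
};

// Fills `self` in place with Bernoulli draws.
// `probs` must hold either one probability per output element or a single
// probability that is reused for every element (a simplified broadcast).
// Every probability must lie in [0, 1]; std::bernoulli_distribution has
// undefined behavior outside that range.
template <typename SelfT, typename ProbT>
void bernoulli_fill(std::vector<SelfT>& self,
                    const std::vector<ProbT>& probs,
                    LockedGenerator& gen) {
  assert(probs.size() == self.size() || probs.size() == 1);

  // Hold the lock for the whole fill so the sequence of draws is not
  // interleaved with draws made by other threads sharing the generator.
  std::lock_guard<std::mutex> lock(gen.mutex_);

  for (std::size_t i = 0; i < self.size(); ++i) {
    const ProbT p_val = (probs.size() == 1) ? probs[0] : probs[i];
    // A fresh distribution per element, since each element may have its
    // own probability.
    std::bernoulli_distribution bernoulli(static_cast<double>(p_val));
    // The draw is a bool; cast it to the output element type.
    self[i] = static_cast<SelfT>(bernoulli(gen.engine));
  }
}
```

With a LockedGenerator constructed from a seed, a call such as bernoulli_fill(out, probs, gen) overwrites out with 0/1 values. The sketch processes elements serially under one lock, which keeps results reproducible for a given seed; the dtype dispatch macros and TensorIterator machinery of the real kernel are intentionally left out.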
