bernoulli_kernel Function Template — PyTorch Architecture
Architecture documentation for the bernoulli_kernel function template in DistributionTemplates.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/DistributionTemplates.h lines 369–399
// Fills `self` in place with Bernoulli samples whose success probabilities are
// read elementwise from `p_`.
//
// RNG        Generator handle; accessed via `->mutex_` and passed to
//            at::bernoulli_distribution::operator().
// self       Output tensor, written in place. Dispatched over all standard
//            types plus Bool, BFloat16 and Half.
// p_         Probability tensor. It is copied to CPU (p_.to(kCPU)) and
//            broadcast against `self` via expand_inplace. Dispatched over
//            Double, Float, BFloat16 and Half: Double probabilities sample
//            with bernoulli_distribution<double>; all others are widened to
//            float and sample with bernoulli_distribution<float>.
// generator  Random generator. Its mutex is held for the whole sampling loop.
//
// Sampling uses cpu_serial_kernel under the generator lock. Presumably this
// keeps all generator draws in one sequential stream; confirm against
// Note [Acquire lock when using random generators].
template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
  // Setup below depends neither on self's dtype nor on the generator. Doing it
  // once, before the dtype dispatch and before taking the generator lock,
  // keeps the critical section limited to the actual sampling. p_.to(kCPU)
  // may allocate and copy, and should not stall other users of the generator.
  auto p_cpu = p_.to(kCPU);
  // `p` may borrow from `p_cpu`, so `p_cpu` must stay alive in this scope.
  auto p = expand_inplace(self, p_cpu);
  auto iter = TensorIteratorConfig()
      .add_output(self)
      .add_const_input(*p)
      .check_all_same_dtype(false)  // output and probability dtypes differ
      .build();
  const auto p_dtype = p->scalar_type();

  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
  self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    using self_t = scalar_t;
    if (p_dtype == kDouble) {
      // Double probabilities keep full precision in the distribution.
      cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
        at::bernoulli_distribution<double> bernoulli(p_val);
        return static_cast<self_t>(bernoulli(generator));
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::BFloat16, at::ScalarType::Half,
      p_dtype, "bernoulli_tensor_cpu_p_", [&] {
        using p_t = scalar_t;
        // Float, BFloat16 and Half probabilities are widened to float.
        cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
          at::bernoulli_distribution<float> bernoulli(p_val);
          return static_cast<self_t>(bernoulli(generator));
        });
      });
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free