btrs Function Template — PyTorch Architecture
Architecture documentation for the btrs function template in Distributions.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/Distributions.h lines 164–215
// Draws one Binomial(count, prob) variate using BTRS (transformed rejection
// with squeeze), following Hormann's algorithm as referenced in the comments
// below.
//
// Each loop iteration consumes exactly two values from `standard_uniform`
// (first for U, then for V). It then either returns a candidate or retries.
// There is no iteration cap: the loop runs until a candidate is accepted.
//
// Template parameters:
//   scalar_t          - type of `count`, `prob`, and the returned sample.
//   accscalar_t       - type of the precomputed coefficients and of the values
//                       drawn from `standard_uniform`.
//   uniform_sampler_t - sampler type wrapped by BaseSampler.
// Parameters:
//   count            - number of trials n.
//   prob             - per-trial success probability p.
//   standard_uniform - source of uniform variates. Presumably sample() returns
//                      values in [0, 1); TODO confirm against BaseSampler.
// Returns:
//   An integral-valued k (produced by compat_floor) with 0 <= k <= count.
//   Candidates outside that range are rejected before any acceptance test.
//
// NOTE(review): `r` divides by (1 - prob), so prob == 1 would divide by zero.
// Per the comment inside the loop, the box acceptance rate drops to ~24% when
// count * prob is small. Callers presumably restrict prob and only take this
// path when count * prob is large enough; verify at the call site.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
scalar_t k;
// U: uniform draw shifted by -0.5 (centered on zero).
// V: second uniform draw.
// us: 0.5 - |U|, the distance of U from the nearer endpoint of its range.
accscalar_t U, V, us;
// This is spq in the paper.
const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob));
// Other coefficients for Transformed Rejection sampling.
const accscalar_t b = 1.15 + 2.53 * stddev;
const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
const accscalar_t c = count * prob + 0.5;
// Upper bound on V for the "tight box" fast-accept test in the loop.
const accscalar_t v_r = 0.92 - 4.2 / b;
// Odds ratio p / (1 - p).
const accscalar_t r = prob / (1 - prob);
const accscalar_t alpha = (2.83 + 5.1 / b) * stddev;
// m = floor((count + 1) * prob), a mode of Binomial(count, prob). It serves as
// the reference point for the log-ratio acceptance test below.
const accscalar_t m = compat_floor((count + 1) * prob);
while (true) {
U = standard_uniform.sample() - 0.5;
V = standard_uniform.sample();
us = 0.5 - compat_abs(U);
// Candidate obtained by transforming U through the hat function;
// the floor makes k integral.
k = static_cast<scalar_t>(compat_floor((2 * a / us + b) * U + c));
// Reject non-sensical answers.
if (k < 0 || k > count) {
continue;
}
// Region for which the box is tight, and we can return our calculated value.
// This should happen 0.86 * v_r times. In the limit as n * p is large,
// the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
if (us >= 0.07 && V <= v_r) {
return k;
}
// This deviates from Hormann's BTRS algorithm, as there is a log missing.
// For all (u, v) pairs outside of the bounding box, this calculates the
// transformed-reject ratio.
// V is overwritten with log(V * alpha / (a / us^2 + b)), so the comparison
// against `upperbound` below happens in log space.
V = compat_log(V * alpha / (a / (us * us) + b));
// Log-space acceptance threshold. It appears to be log(f(k) / f(m)) for the
// binomial pmf f, expanded with Stirling's formula plus
// stirling_approx_tail correction terms at m, count - m, k and count - k.
// Verify against Hormann's paper before changing.
accscalar_t upperbound =
((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) +
(count + 1) * compat_log((count - m + 1) / (count - k + 1)) +
(k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) +
stirling_approx_tail<accscalar_t>(m) + stirling_approx_tail<accscalar_t>(count - m) -
stirling_approx_tail<accscalar_t>(k) - stirling_approx_tail<accscalar_t>(count - k));
if (V <= upperbound) {
return k;
}
// Rejected: loop again with two fresh uniform draws.
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free