Home / Class / btrs Class — pytorch Architecture

btrs Class — pytorch Architecture

Architecture documentation for btrs, a function template (listed on this site under "Class") defined in aten/src/ATen/native/Distributions.h in the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/Distributions.h lines 164–215

// Draws a single binomial variate for `count` trials with success probability
// `prob` using transformed rejection sampling (Hormann's BTRS scheme).
//
// Each attempt consumes two values from `standard_uniform`, turns them into an
// integer candidate, and either accepts it (cheap test first, exact
// log-ratio test second) or starts over. The loop only exits by returning an
// accepted candidate.
//
// NOTE(review): the numeric constants are taken as-is from the algorithm; the
// parameter regime in which they are valid is presumably enforced by the
// caller -- confirm.
template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  // Referred to as "spq" in the paper.
  const accscalar_t spq = compat_sqrt(count * prob * (1 - prob));

  // Remaining constants of the transformed-rejection setup.
  const accscalar_t b = 1.15 + 2.53 * spq;
  const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
  const accscalar_t c = count * prob + 0.5;
  const accscalar_t v_r = 0.92 - 4.2 / b;
  const accscalar_t odds = prob / (1 - prob);

  const accscalar_t alpha = (2.83 + 5.1 / b) * spq;
  const accscalar_t mode = compat_floor((count + 1) * prob);

  for (;;) {
    // First draw is shifted by 0.5; second draw is used unshifted.
    const accscalar_t u = standard_uniform.sample() - 0.5;
    const accscalar_t v = standard_uniform.sample();

    const accscalar_t us = 0.5 - compat_abs(u);
    const scalar_t candidate =
        static_cast<scalar_t>(compat_floor((2 * a / us + b) * u + c));

    // A candidate outside [0, count] cannot be a valid outcome: retry.
    if (candidate < 0 || candidate > count) {
      continue;
    }

    // Fast path: inside the region where the bounding box is tight the
    // candidate can be returned immediately. This is expected to fire
    // 0.86 * v_r of the time; for large n * p the acceptance rate tends to
    // ~79%, while in the lower regime it is ~24%.
    if (us >= 0.07 && v <= v_r) {
      return candidate;
    }

    // Slow path for (u, v) pairs outside the bounding box: evaluate the
    // transformed-rejection ratio in log space. This departs from Hormann's
    // published BTRS algorithm, which is missing a log here.
    const accscalar_t log_v = compat_log(v * alpha / (a / (us * us) + b));
    const accscalar_t accept_bound =
        ((mode + 0.5) * compat_log((mode + 1) / (odds * (count - mode + 1))) +
         (count + 1) * compat_log((count - mode + 1) / (count - candidate + 1)) +
         (candidate + 0.5) * compat_log(odds * (count - candidate + 1) / (candidate + 1)) +
         stirling_approx_tail<accscalar_t>(mode) + stirling_approx_tail<accscalar_t>(count - mode) -
         stirling_approx_tail<accscalar_t>(candidate) - stirling_approx_tail<accscalar_t>(count - candidate));

    if (log_v <= accept_bound) {
      return candidate;
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free