Home / Class / _beta_grad_alpha_mid Class — pytorch Architecture

_beta_grad_alpha_mid Class — pytorch Architecture

Architecture documentation for `_beta_grad_alpha_mid` (listed under "Class", although the source below defines it as a function template) in Distributions.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/Distributions.h lines 408–441

// Evaluates a closed-form expression of x, alpha and beta and returns it as scalar_t;
// intermediates are stored as accscalar_t.
// NOTE(review): going by the name, this presumably approximates the gradient of a
// Beta(alpha, beta) sample x with respect to alpha in a "middle" regime -- confirm
// against the callers.
//
// Two formulas are used:
//   * when x is within 0.1 * sigma of center = alpha / (alpha + beta), a polynomial in
//     alpha, beta and x is evaluated instead of the general formula;
//   * otherwise a combination of four terms is scaled by a leading factor and by a
//     ratio of truncated series in 1 / alpha, 1 / beta and 1 / (alpha + beta).
// center and sigma match the mean and standard deviation of a Beta(alpha, beta)
// distribution.
template<typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
  const accscalar_t ab_sum = alpha + beta;
  const accscalar_t center = alpha / ab_sum;
  const accscalar_t sigma = compat_sqrt(alpha * beta / (ab_sum + 1)) / ab_sum;

  // The general formula below divides by alpha * (x - 1) + beta * x, which is zero
  // exactly when x == center, so a narrow window around center is handled separately.
  const bool near_center = (center - 0.1 * sigma <= x) && (x <= center + 0.1 * sigma);
  if (near_center) {
    // Nested (Horner-style) polynomial in alpha with coefficients depending on beta and x.
    const accscalar_t series = 47 * x * (beta * beta) * (beta * beta) + alpha * (
        (43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
        3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
        (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
        8 * (1 - x) * (135 * beta - 11)))));
    const accscalar_t scale_top = (1 + 12 * alpha) * (1 + 12 * beta) / (ab_sum * ab_sum);
    const accscalar_t scale_bottom = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * ab_sum);
    // The accscalar_t result is converted to scalar_t implicitly by the return.
    return scale_top / (1 - x) * series / scale_bottom;
  }

  const accscalar_t lead = -x / compat_sqrt(2 * alpha * beta / ab_sum);

  // Each factor has the form 1 + 1/(12 z) + 1/(288 z^2), i.e. the first three terms of
  // the Stirling series, for z = alpha, beta and alpha + beta respectively.
  const accscalar_t correction = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
                               * (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
                               / (1 + 1 / (12 * ab_sum) + 1 / (288 * ab_sum * ab_sum));

  // First term: a rational function whose denominator contains the squared offset.
  const accscalar_t first_top = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
  const accscalar_t offset = alpha * (x - 1) + beta * x;
  const accscalar_t first_bottom = compat_sqrt(2 * alpha / beta) * compat_pow(ab_sum, static_cast<accscalar_t>(1.5f)) * offset * offset;
  const accscalar_t first = first_top / first_bottom;

  // Second term: half the log of alpha / ((alpha + beta) * x).
  const accscalar_t second = 0.5f * compat_log(alpha / (ab_sum * x));

  // Third term: divides by the same quantity that vanishes at x == center.
  const accscalar_t third_top = compat_sqrt(8 * alpha * beta / ab_sum);
  const accscalar_t third_bottom = beta * x + alpha * (x - 1);
  const accscalar_t third = third_top / third_bottom;

  // Fourth term: a weighted sum of logs raised to the power -1.5; its sign depends on
  // which side of center x falls.
  const accscalar_t log_mix = beta * compat_log(beta / (ab_sum * (1 - x))) +
                              alpha * compat_log(alpha / (ab_sum * x));
  const accscalar_t fourth = compat_pow(log_mix, static_cast<accscalar_t>(-1.5f));
  const accscalar_t signed_fourth = x < center ? fourth : -fourth;

  const accscalar_t combined = first + second * (third + signed_fourth);
  return static_cast<scalar_t>(correction * lead * combined);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free