_beta_grad_alpha_mid Function — pytorch Architecture
Architecture documentation for the _beta_grad_alpha_mid function template in Distributions.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/Distributions.h lines 408–441
// Approximates the derivative with respect to alpha used for reparameterized
// Beta(alpha, beta) sampling, evaluated at the sample value x.
// NOTE(review): the "_mid" suffix presumably means this variant targets the regime
// where alpha and beta are not small (it relies on asymptotic Stirling-series
// terms below) — confirm against the dispatching caller.
//
// Template parameters:
//   scalar_t    - type of the returned value.
//   accscalar_t - type in which all intermediate arithmetic is performed.
// Parameters:
//   x           - point at which the gradient is evaluated. Presumably 0 < x < 1:
//                 log(x), log(1 - x) and division by (1 - x) appear below.
//                 Not validated here.
//   alpha, beta - Beta concentration parameters, presumably > 0. Not validated here.
// Returns:
//   The approximate gradient, converted to scalar_t.
//
// Two regimes are used:
//   * x within [mean - 0.1 * std, mean + 0.1 * std]: a closed-form rational
//     expression, because the general formula is singular at x == mean.
//   * otherwise: stirling * prefactor * (term1 + term2 * (term3 +/- term4)).
template<typename scalar_t, typename accscalar_t>
C10_DEVICE inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
// Mean and standard deviation of a Beta(alpha, beta) distribution:
//   mean = alpha / (alpha + beta)
//   std  = sqrt(alpha * beta / (alpha + beta + 1)) / (alpha + beta)
const accscalar_t total = alpha + beta;
const accscalar_t mean = alpha / total;
const accscalar_t std = compat_sqrt(alpha * beta / (total + 1)) / total;
// NOTE(review): `0.1` is a double literal. When accscalar_t is float, this
// comparison is evaluated in double precision (usual arithmetic conversions),
// unlike the `0.5f` literal used further down. Confirm whether that is intended.
if (mean - 0.1 * std <= x && x <= mean + 0.1 * std) {
// Avoid the singularity at x = mean.
// In the general branch below, term1 and term3 divide by
// alpha * (x - 1) + beta * x == total * (x - mean), and term4 raises a
// quantity that is 0 at x == mean to the power -1.5. Near the mean, a
// rational expression is used instead. `poly` is written in nested (Horner)
// form in alpha, with coefficients that depend on beta and x.
const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
(43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
(453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
8 * (1 - x) * (135 * beta - 11)))));
const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
// This expression has type accscalar_t and is implicitly converted to scalar_t
// on return. The general branch below uses an explicit static_cast instead.
return prefactor_num / (1 - x) * poly / prefactor_den;
}
// General branch (x outside the window around the mean).
const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);
// Ratio S(alpha) * S(beta) / S(total), where S(z) = 1 + 1/(12 z) + 1/(288 z^2)
// is the leading part of Stirling's series for Gamma(z).
const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
* (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
/ (1 + 1 / (12 * total) + 1 / (288 * total * total));
const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
// axbx == (alpha + beta) * x - alpha == total * (x - mean); vanishes at x == mean.
const accscalar_t axbx = alpha * (x - 1) + beta * x;
const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast<accscalar_t>(1.5f)) * axbx * axbx;
const accscalar_t term1 = term1_num / term1_den;
// alpha / (total * x) == mean / x.
const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));
const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
// Same quantity as axbx above, i.e. total * (x - mean).
const accscalar_t term3_den = beta * x + alpha * (x - 1);
const accscalar_t term3 = term3_num / term3_den;
// term4_base == beta * log((1 - mean) / (1 - x)) + alpha * log(mean / x)
//            == total * KL(Bernoulli(mean) || Bernoulli(x)),
// which is mathematically non-negative and zero only at x == mean.
const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
alpha * compat_log(alpha / (total * x));
const accscalar_t term4 = compat_pow(term4_base, static_cast<accscalar_t>(-1.5f));
// term4 enters with a positive sign below the mean and a negative sign at or above it.
const accscalar_t term1234 = term1 + term2 * (term3 + (x < mean ? term4 : -term4));
return static_cast<scalar_t>(stirling * prefactor * term1234);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free