dirichlet_grad_one Function — PyTorch Architecture
Architecture documentation for the dirichlet_grad_one function template in Distributions.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/Distributions.h lines 448–510
// Picks one of four approximations based on where (x, alpha, total) falls and
// returns its value as scalar_t.
//
// NOTE(review): judging by the name and the _beta_grad_* helpers this looks
// like the gradient of a Beta/Dirichlet sample x with respect to alpha, with
// total = alpha + beta -- confirm against the callers.
//
// Template parameters:
//   scalar_t     type of the three inputs and of the returned value.
//   accscalar_t  type used for the intermediate arithmetic of the
//                "both large" and rational-correction branches.
//
// Parameters:
//   x      evaluation point; compared against 0.5 to choose a branch.
//   alpha  first parameter; beta is derived as (total - alpha).
//   total  the value from which beta is derived.
template<typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
  const scalar_t beta = total - alpha;
  const scalar_t boundary = total * x * (1 - x);

  // x close to 0: asymptotic approximation.
  if (x <= 0.5f && boundary < 2.5f) {
    return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
  }

  // x close to 1: asymptotic approximation, evaluated at the mirrored point
  // (1 - x) with the two parameters swapped, then negated.
  if (x >= 0.5f && boundary < 0.75f) {
    return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
  }

  // The remaining branches work in the accumulation type.
  accscalar_t x_ = static_cast<accscalar_t>(x);
  accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
  accscalar_t total_ = static_cast<accscalar_t>(total);
  const accscalar_t beta_ = total_ - alpha_;

  // alpha and (total - alpha) both large: asymptotic approximation.
  if (alpha > 6 && beta > 6) {
    return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
  }

  // Otherwise: a rational correction applied to an analytic approximation.
  // Table layout is coef[part][i][j][k], where part 0 feeds the numerator and
  // part 1 the denominator, i is the power of log_x, j the power of log_a and
  // k the power of log_b.
  static const accscalar_t coef[2][3][3][4] = {
    {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
      {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
      {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
     {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
      {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
      {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
     {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
      {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
      {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
    {{{1, -0.02924021934, -0.04438342661, 0.007285809825},
      {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
      {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
     {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
      {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
      {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
     {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
      {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
      {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
  };

  // The three variables of the rational function.
  const accscalar_t log_x = compat_log(x_);
  const accscalar_t log_a = compat_log(alpha_) - log_x;
  const accscalar_t log_b = compat_log(total_) - log_a;

  // Powers 0..2 of the first two variables; the third is handled by Horner's
  // scheme inside the loop.
  const accscalar_t log_x_pow[3] = {1, log_x, log_x * log_x};
  const accscalar_t log_a_pow[3] = {1, log_a, log_a * log_a};

  accscalar_t numer = 0.0;
  accscalar_t denom = 0.0;
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 3; ++j) {
      const accscalar_t basis = log_x_pow[i] * log_a_pow[j];
      // Cubic in log_b for this (i, j) term, one per part of the fraction.
      const accscalar_t numer_cubic =
          coef[0][i][j][0] + log_b * (coef[0][i][j][1] + log_b * (coef[0][i][j][2] + log_b * coef[0][i][j][3]));
      const accscalar_t denom_cubic =
          coef[1][i][j][0] + log_b * (coef[1][i][j][1] + log_b * (coef[1][i][j][2] + log_b * coef[1][i][j][3]));
      numer += basis * numer_cubic;
      denom += basis * denom_cubic;
    }
  }

  // Analytic approximation that the rational factor numer / denom corrects.
  const accscalar_t digamma_gap =
      digamma_one<scalar_t, accscalar_t>(total_) - digamma_one<scalar_t, accscalar_t>(alpha_);
  const accscalar_t approx = x_ * digamma_gap / beta_;
  return static_cast<scalar_t>(numer / denom * approx);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free