Home / Class / dirichlet_grad_one Class — pytorch Architecture

dirichlet_grad_one Class — pytorch Architecture

Architecture documentation for the dirichlet_grad_one function template in Distributions.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/Distributions.h lines 448–510

// Approximates a gradient for the Dirichlet/Beta sampler. It picks one of four
// approximation regimes based on x, total * x * (1 - x), alpha and
// beta = total - alpha.
// NOTE(review): judging from the function name and the _beta_grad_* helpers,
// this appears to be d(x)/d(alpha) for a sample x ~ Beta(alpha, total - alpha).
// Confirm against callers before relying on that interpretation.
//
// Template parameters:
//   scalar_t    - type of the inputs and of the returned value. All branch
//                 tests (x, boundary, alpha, beta) are evaluated in this type.
//   accscalar_t - type used for the intermediate math of the last two regimes.
//                 x, alpha and total are cast to it up front.
//
// Parameters:
//   x     - evaluation point. The regime tests split at 0.5 and use
//           x * (1 - x), and the last regime takes log(x), so x is presumably
//           expected in (0, 1). TODO confirm with callers.
//   alpha - first parameter. beta = total - alpha is the complementary one.
//   total - alpha + beta.
//
// Returns the selected approximation, converted to scalar_t.
template<typename scalar_t, typename accscalar_t>
C10_HOST_DEVICE inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
  // Wider-precision copies of the inputs for the mid-range and rational regimes.
  accscalar_t x_ = static_cast<accscalar_t>(x);
  accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
  accscalar_t total_ = static_cast<accscalar_t>(total);

  // beta is kept in both precisions. The scalar_t copy drives the branch tests
  // and the small-x helpers; the accscalar_t copy feeds the wider-precision regimes.
  const scalar_t beta = total - alpha;
  const accscalar_t beta_ = total_ - alpha_;
  // Regime selector; computed in scalar_t (input) precision.
  const scalar_t boundary = total * x * (1 - x);

  // Use an asymptotic approximation for x close to 0.
  // At x == 0.5 both this test and the next can hold; this one is checked first.
  if (x <= 0.5f && boundary < 2.5f) {
    return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
  }

  // Use an asymptotic approximation for x close to 1.
  // The problem is mirrored: the helper receives 1 - x with alpha and beta
  // swapped, and its result is negated.
  if (x >= 0.5f && boundary < 0.75f) {
    return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
  }

  // Use an asymptotic approximation when alpha and (total - alpha) are both large.
  // Unlike the two branches above, this helper receives accscalar_t arguments.
  if (alpha > 6 && beta > 6) {
    return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
  }

  // Use a rational correction to an analytic approximation.
  // Coefficient layout: c[k][i][j][m] multiplies u^i * a^j * b^m.
  // k == 0 gives the numerator p and k == 1 gives the denominator q (see the
  // loop below). The values look like fitted constants
  // (NOTE(review): presumably produced offline), so do not hand-edit them.
  // `static` storage means the table is not rebuilt on every call.
  static const accscalar_t c[2][3][3][4] = {
    {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
      {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
      {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
     {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
      {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
      {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
     {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
      {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
      {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
    {{{1, -0.02924021934, -0.04438342661, 0.007285809825},
      {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
      {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
     {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
      {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
      {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
     {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
      {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
      {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
  };
  // Log-space features of the rational fit:
  //   u = log(x)
  //   a = log(alpha) - log(x)  = log(alpha / x)
  //   b = log(total) - a       = log(total * x / alpha)
  const accscalar_t u = compat_log(x_);
  const accscalar_t a = compat_log(alpha_) - u;
  const accscalar_t b = compat_log(total_) - a;
  const accscalar_t pow_u[3] = {1, u, u * u};
  const accscalar_t pow_a[3] = {1, a, a * a};
  accscalar_t p = 0.0;
  accscalar_t q = 0.0;
  // p and q are each sum_{i,j} u^i a^j * (cubic polynomial in b). The cubic is
  // evaluated in Horner form. Keep this evaluation order: changing it changes
  // floating-point rounding.
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 3; ++j) {
      const accscalar_t ua = pow_u[i] * pow_a[j];
      p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
      q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
    }
  }
  // Analytic base approximation: x * (digamma(total) - digamma(alpha)) / beta,
  // scaled by the rational correction p / q.
  // NOTE(review): divides by beta_, so this assumes total > alpha. Confirm
  // that callers guarantee it.
  const accscalar_t approx = x_ * (digamma_one<scalar_t, accscalar_t>(total_) - digamma_one<scalar_t, accscalar_t>(alpha_)) / beta_;
  return static_cast<scalar_t>(p / q * approx);
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free