Home / Class / compute_T8 Class — pytorch Architecture

compute_T8 Class — pytorch Architecture

Architecture documentation for the compute_T8 function template in LinearAlgebra.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/LinearAlgebra.cpp lines 2387–2437

// Computes T8(A), the degree-8 Taylor polynomial of exp(A), using the
// reduced-product evaluation scheme
//
//   A2 = A^2                       (provided by _fill_matrix_powers)
//   A4 = A2 * (x1 * A + x2 * A2)
//   A8 = (x3 * A2 + A4) * (x4 * I + x5 * A + x6 * A2 + x7 * A4)
//   T8 = I + A + y2 * A2 + A8
//
// Expanding A8 (with A4 = x1 * A^3 + x2 * A^4) and matching every power A^k
// against A^k / k! fixes the coefficients. For example:
//   A^2 term: x3 * x4 + y2          must equal 1 / 2!
//   A^7 term: 2 * x1 * x2 * x7      must equal 1 / 7!
//   A^8 term: x2 * x2 * x7          must equal 1 / 8!
//
// A: the input matrix (or batch of matrices). NOTE(review): assumed square,
//    with matrices in the trailing two dims, as the matmul helpers require;
//    confirm against callers.
// Returns T8(A), formed by _linear_combination over the buffer of powers.
template <typename scalar_t>
Tensor compute_T8(const Tensor& A) {
  constexpr scalar_t sqrt_177 = 0.1330413469565007072504e+2;
  constexpr scalar_t x3 = 2. / 3.;
  constexpr scalar_t x1 = x3 * ((1. + sqrt_177) / 88.);
  constexpr scalar_t x2 = x3 * ((1. + sqrt_177) / 352.);
  constexpr scalar_t x4 = (-271. + 29. * sqrt_177) / (315. * x3);
  constexpr scalar_t x5 = (-11. + 11. * sqrt_177) / (1260. * x3);
  constexpr scalar_t x6 = (-99. + 11. * sqrt_177) / (5040. * x3);
  // x3 enters squared here. With a single factor of x3, the A^5..A^8
  // coefficients of T8 no longer equal 1/5!..1/8!. For example, x2 * x2 * x7
  // would come out as x3 / 8! instead of 1 / 8!.
  constexpr scalar_t x7 = (89. - sqrt_177) / (5040. * x3 * x3);
  constexpr scalar_t y2 = (857. - 58. * sqrt_177) / 630.;

  // Buffer of matrices along dim 0: [I, A, A2, A4, A8].
  auto As = _allocate_buffer(A, 5);
  // 3 for {I, A, A^2}; slots 3 (A4) and 4 (A8) are written below.
  _fill_matrix_powers(As, A, 3);

  // output for A4 (slot 3)
  auto view_out = As.select(0, 3);
  // A4 = A2 * (x1 * A + x2 * A2)
  _matmul_impl(
    view_out,
    // As.select(0, 2) = A^2
    As.select(0, 2),
    _linear_combination<scalar_t>(
      // extract {A, A^2} from As
      As.narrow(0, 1, 2),
      {x1, x2}
    )
  );

  // output for A8 (slot 4)
  view_out = As.select(0, 4);
  // A8 = (x3 * A2 + A4) * (x4 * I + x5 * A + x6 * A2 + x7 * A4)
  _matmul_impl(
    view_out,
    // x3 * A2 + A4, from slots {A2, A4}
    _linear_combination<scalar_t>(
      As.narrow(0, 2, 2),
      {x3, 1.0}
    ),
    // x4 * I + x5 * A + x6 * A2 + x7 * A4, from slots {I, A, A2, A4}
    _linear_combination<scalar_t>(
      As.narrow(0, 0, 4),
      {x4, x5, x6, x7}
    )
  );

  // T8 = I + A + y2 * A2 + A8. A4 is only an intermediate, so its weight is 0.
  return _linear_combination<scalar_t>(
    As, {1.0, 1.0, y2, 0.0, 1.0}
  );
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free