compute_T8 Function — pytorch Architecture
Architecture documentation for the compute_T8 function template in LinearAlgebra.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/LinearAlgebra.cpp lines 2387–2437
template <typename scalar_t>
Tensor compute_T8(const Tensor& A) {
  // Fixed scalar coefficients for the degree-8 evaluation scheme below.
  // sqrt_177 is a precomputed high-precision literal; x1..x7 and y2 are
  // derived from it at compile time. NOTE(review): values look like they
  // come from an optimized Taylor-polynomial scheme for the matrix
  // exponential — confirm against the cited paper before altering.
  constexpr scalar_t sqrt_177 = 0.1330413469565007072504e+2;
  constexpr scalar_t x3 = 2. / 3.;
  constexpr scalar_t x1 = x3 * ((1. + sqrt_177) / 88.);
  constexpr scalar_t x2 = x3 * ((1. + sqrt_177) / 352.);
  constexpr scalar_t x4 = (-271. + 29. * sqrt_177) / (315. * x3);
  constexpr scalar_t x5 = (-11. + 11. * sqrt_177) / (1260. * x3);
  constexpr scalar_t x6 = (-99. + 11. * sqrt_177) / (5040. * x3);
  constexpr scalar_t x7 = (89. - sqrt_177) / (5040. * x3);
  constexpr scalar_t y2 = (857. - 58. * sqrt_177) / 630.;

  // One backing buffer with five matrix slots, used as
  // {I, A, A^2, A4, A8}; the first three are filled here.
  auto buf = _allocate_buffer(A, 5);
  _fill_matrix_powers(buf, A, 3);

  // Slot 3: A4 = A^2 * (x1 * A + x2 * A^2)
  auto a4 = buf.select(0, 3);
  _matmul_impl(
      a4,
      buf.select(0, 2),                // A^2
      _linear_combination<scalar_t>(
          buf.narrow(0, 1, 2),         // {A, A^2}
          {x1, x2}));

  // Slot 4: A8 = (x3 * A^2 + A4) * (x4 * I + x5 * A + x6 * A^2 + x7 * A4)
  auto a8 = buf.select(0, 4);
  _matmul_impl(
      a8,
      _linear_combination<scalar_t>(
          buf.narrow(0, 2, 2),         // {A^2, A4}
          {x3, 1.0}),
      _linear_combination<scalar_t>(
          buf.narrow(0, 0, 4),         // {I, A, A^2, A4}
          {x4, x5, x6, x7}));

  // Final result: I + A + y2 * A^2 + A8 (slot 3 gets weight 0, so the
  // intermediate A4 does not contribute directly).
  return _linear_combination<scalar_t>(buf, {1.0, 1.0, y2, 0.0, 1.0});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free