compute_T4 Class — pytorch Architecture

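Architecture documentation for the compute_T4 function template in LinearAlgebra.cpp from the PyTorch codebase. Despite the "Class" label in the page title, compute_T4 is a free function templated on scalar_t, not a class.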

Source Code

aten/src/ATen/native/LinearAlgebra.cpp lines 2362–2385

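// Degree-4 Taylor polynomial of exp(A): I + A + A^2/2 + A^3/6 + A^4/24,
// factored so that no power above A^2 is formed explicitly.
template <typename scalar_t>
Tensor compute_T4(const Tensor& A) {
  // 4 slots: {I, A, A^2} plus one slot for the product computed below
  auto As = _allocate_buffer(A, 4);
  // 3 for {I, A, A^2}
  _fill_matrix_powers(As, A, 3);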

  // output for A^2 * (I / 2 + A / 6 + A^2 / 24)
  auto view_out = As.select(0, 3);
  _matmul_impl(
    view_out,
    // contains A^2
    As.select(0, 2),
    // computes (I / 2 + A / 6 + A^2 / 24)
    _linear_combination<scalar_t>(
      As.narrow(0, 0, 3),
      {1 / 2.0, 1 / 6.0, 1 / 24.0}
    )
  );

  // I + A + A^2 * (I / 2 + A / 6 + A^2 / 24)
  // (the A^2 slot gets coefficient 0: it only enters through slot 3)
  return _linear_combination<scalar_t>(
    As, {1.0, 1.0, 0.0, 1.0}
  );
}
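
The factoring used above rests on the identity I + A + A^2/2 + A^3/6 + A^4/24 = I + A + A^2 * (I/2 + A/6 + A^2/24). The sketch below is a minimal, standalone illustration of that identity on plain 2x2 matrices. It does not use ATen: Mat2, matmul, linear_combination, taylor4_factored, taylor4_direct, and max_abs_diff are hypothetical names introduced only for this example, and they mirror the shape of the listing rather than reproduce PyTorch's helpers.

```cpp
#include <array>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical 2x2 matrix type for illustration only (not an ATen type).
using Mat2 = std::array<std::array<double, 2>, 2>;

Mat2 identity() {
  Mat2 m{};  // value-initialized to all zeros
  m[0][0] = 1.0;
  m[1][1] = 1.0;
  return m;
}

Mat2 matmul(const Mat2& a, const Mat2& b) {
  Mat2 c{};
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      for (std::size_t k = 0; k < 2; ++k)
        c[i][j] += a[i][k] * b[k][j];
  return c;
}

// Weighted sum of matrices: sum_i coeffs[i] * ms[i].
// Assumes ms and coeffs have the same length.
Mat2 linear_combination(const std::vector<Mat2>& ms,
                        const std::vector<double>& coeffs) {
  Mat2 out{};
  for (std::size_t n = 0; n < ms.size(); ++n)
    for (std::size_t i = 0; i < 2; ++i)
      for (std::size_t j = 0; j < 2; ++j)
        out[i][j] += coeffs[n] * ms[n][i][j];
  return out;
}

// Factored form, following the structure of the listing above:
// I + A + A^2 * (I/2 + A/6 + A^2/24).
Mat2 taylor4_factored(const Mat2& A) {
  const Mat2 I = identity();
  const Mat2 A2 = matmul(A, A);
  const Mat2 inner =
      linear_combination({I, A, A2}, {1 / 2.0, 1 / 6.0, 1 / 24.0});
  const Mat2 tail = matmul(A2, inner);
  // A2 gets weight 0 here because it only enters through `tail`.
  return linear_combination({I, A, A2, tail}, {1.0, 1.0, 0.0, 1.0});
}

// Direct form for comparison: forms A^3 and A^4 explicitly.
Mat2 taylor4_direct(const Mat2& A) {
  const Mat2 I = identity();
  const Mat2 A2 = matmul(A, A);
  const Mat2 A3 = matmul(A2, A);
  const Mat2 A4 = matmul(A3, A);
  return linear_combination({I, A, A2, A3, A4},
                            {1.0, 1.0, 1 / 2.0, 1 / 6.0, 1 / 24.0});
}

// Largest absolute entrywise difference between two matrices.
double max_abs_diff(const Mat2& a, const Mat2& b) {
  double d = 0.0;
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      d = std::fmax(d, std::fabs(a[i][j] - b[i][j]));
  return d;
}
```

Calling taylor4_factored and taylor4_direct on the same matrix should agree up to floating-point rounding. The factored form needs two matrix products (A^2 and the final multiply), while the direct form in this sketch needs three.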