Home / Class / mexp_impl Class — pytorch Architecture

mexp_impl Class — pytorch Architecture

Architecture documentation for the `mexp_impl` function template in LinearAlgebra.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/LinearAlgebra.cpp lines 2642–2717

// Evaluates the `mexp` approximation (presumably the matrix exponential, per
// the name) for a batch of matrices `a`, choosing per matrix which
// approximation degree to use based on its operator 1-norm.
//
// Parameters:
//   a      - batched input; dimension 0 is treated as the batch dimension
//            (assumed (batch, n, n) square matrices — TODO confirm at callers).
//   thetas - norm thresholds. For 0 <= i < total_n_degs - 1, thetas[i] is the
//            inclusive upper norm bound for which `compute_Ts[i]` is used.
//            thetas[total_n_degs - 1] is forwarded to
//            `compute_T18_scale_square`.
//   compute_highest_degree_approx - when true, skip per-matrix selection and
//            apply `compute_T18_scale_square` to the whole input. It is forced
//            to true whenever the batch holds more than one matrix, so the
//            selective path below only runs for batch sizes 0 or 1.
//
// Returns a tensor shaped like `a`. On the selective path the result starts
// out filled with NaN. Any matrix whose norm is NaN fails every interval
// comparison, so it is never overwritten and stays NaN.
template <typename scalar_t>
Tensor mexp_impl(
  const Tensor& a,
  std::array<scalar_t, total_n_degs> thetas,
  bool compute_highest_degree_approx = false
) {
  const auto norm = operator_1_norm(a);
  const auto batch_size = a.size(0);
  if (batch_size > 1) {
    compute_highest_degree_approx = true;
  }

  if (!compute_highest_degree_approx) {
    // To prevent undefined behavior which outputs "normal" result from a matrix
    // containing NaN values, we put NaN values in `res`, so if input has NaN
    // values, its computation will be skipped to return the NaN-filled `res`
    // directly.
    auto res = at::full_like(a, std::numeric_limits<double>::quiet_NaN(), {},
                             at::MemoryFormat::Contiguous);
    // `norm_cpu` is used to decide which Tensors require which approximation
    // based on their norm. This decision takes place on CPU.
    // It requires moving data back and forth between devices when `a` is on CUDA,
    // but at the cost of only one single CPU-CUDA synchronization (instead of 6),
    // and better performance overall (benchmarked).
    const auto norm_cpu = (a.device().type() == at::kCUDA)
      ? norm.to(at::kCPU) : norm;

    // compute_Ts[i] handles norms in (thetas[i - 1], thetas[i]].
    constexpr std::array<
      Tensor(*)(const Tensor&),
      total_n_degs - 1>
    compute_Ts = {
      compute_T1, compute_T2, compute_T4<scalar_t>,
      compute_T8<scalar_t>, compute_T12<scalar_t>
    };

    for (int i = 0; i < total_n_degs - 1; ++i) {
      // Norms are non-negative, so -1 makes the first interval include 0.
      auto norm_lower_bound = (i == 0) ? static_cast<scalar_t>(-1) : thetas[i - 1];
      auto norm_upper_bound = thetas[i];
      // nonzero returns a 2D tensor, hence squeeze(-1) to make it 1D
      auto idx_curr_norm_interval = (
        (norm_lower_bound < norm_cpu) * (norm_cpu <= norm_upper_bound)
      ).nonzero().squeeze(-1);

      if (idx_curr_norm_interval.numel()) {
        auto idx_to_device = _move_memory_if_cuda_input(
          idx_curr_norm_interval, a
        );
        auto sub_a = at::index_select(a, 0, idx_to_device);
        res.index_put_({idx_to_device}, compute_Ts[i](sub_a));
      }
    }

    // Strict `>`: the last loop interval is inclusive at thetas[total_n_degs - 2],
    // so a norm exactly equal to it has already been handled there. Using `>=`
    // would compute it a second time and overwrite that result.
    // nonzero returns a 2D tensor, hence squeeze(-1) to make it 1D
    auto idx_large_norm = (norm_cpu > thetas[total_n_degs - 2])
      .nonzero().squeeze(-1);

    if (idx_large_norm.numel()) {
      auto idx_to_device = _move_memory_if_cuda_input(
        idx_large_norm, a
      );
      auto a_large_norm = at::index_select(a, 0, idx_to_device);
      auto large_norm_subset = at::index_select(norm, 0, idx_to_device);
      auto mexp_out = compute_T18_scale_square(
        a_large_norm,
        large_norm_subset,
        thetas[total_n_degs - 1]
      );
      // Write back with the same (device-resident) index used for the reads,
      // consistent with the per-interval loop above.
      res.index_put_({idx_to_device}, mexp_out);
    }
    return res;
  }

  return compute_T18_scale_square(
    a, norm,
    thetas[total_n_degs - 1]
  );
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free