compute_T18_scale_square Function — PyTorch Architecture

Architecture documentation for the compute_T18_scale_square function template in LinearAlgebra.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/LinearAlgebra.cpp lines 2572–2640

template <typename scalar_t>
Tensor compute_T18_scale_square(
  const Tensor& a,
  const Tensor& norm,
  scalar_t theta
) {
  // Scale
  // We will eventually need repeated matrix squarings to recover the result.
  // For example, if `norm` equals [27, 6, 6, 0.05], we end up with `s` equal
  // to [4, 1, 1, 0], which would mean computing matrix[0]^(2^4),
  // matrix[1]^(2^1) and matrix[2]^(2^1) one by one; such one-by-one
  // computation would be quite slow.
  const auto s = (at::ceil(at::log2(norm / theta))).clamp(/*min=*/0);
  const auto pow2s = at::pow(2, -s);
  const auto a_scaled = a * pow2s.view({-1, 1, 1});
  auto mexp_scaled = at::native::compute_T18<scalar_t>(a_scaled);

  // Sort:
  // The inputs are square matrices, so if we first square matrices 0, 1 and 2
  // together, the only remaining work is to keep squaring `matrix 0` on its
  // own, which gives us an opportunity to perform the matrix multiplications
  // in batches.
  // The first step is to sort the tensor `s`, which lets us do the matrix
  // multiplications range by range.
  // With the above example, `sorted_s` is [0, 1, 1, 4]; we also keep the index
  // information so we can compose the result back at the end.
  auto [sorted_s, sorted_s_inds] = at::sort(s, /*dim=*/0);
  sorted_s = sorted_s.to(at::kLong);
  // Then we call `unique_consecutive` to split `sorted_s` into runs of equal
  // values; with the above example, `split_counts` is [1, 2, 1].
  auto split_counts = std::get<2>(at::unique_consecutive(sorted_s, true, /*return_counts=*/true));
  // We also need the index of the last element of each split, so we know how
  // many extra squarings each split needs.
  // Note that we never compute the actual powers explicitly, because we reuse
  // the cumulative matrix squarings.
  // With the above example, `mul_times` will be [0, 1, 3].
  auto split_edges = at::cumsum(split_counts, /*dim=*/0) - 1;
  auto unique_s = sorted_s.index_select(0, split_edges).clamp(/*min=*/0);
  auto mul_times = at::diff(unique_s, 1, -1, /*prepend=*/unique_s.new_zeros({1}));

  // Square
  auto section_values = at::cat({split_counts, mul_times}, 0).to(at::kCPU);

  TORCH_INTERNAL_ASSERT(section_values.is_contiguous());
  const auto section_numel = section_values.numel() / 2;
  auto scs = section_values.template data_ptr<int64_t>();
  auto pts = &scs[section_numel];

  // We can now do the matrix multiplications in batches. With the above
  // example:
  // 1. Square all matrices 0 (`mul_times[0]`) times, then `slice` off the
  //    first `split_counts[0]` finished matrices, keeping acc[1:],
  // 2. Square the remaining matrices 1 more time and slice to acc[2:],
  // 3. Square the remaining matrix 3 more times.
  // All processed matrices are collected in `output_pieces`.
  std::vector<Tensor> output_pieces;
  auto acc = mexp_scaled.index_select(0, sorted_s_inds);
  for (int64_t i = 0; i < section_numel; ++i) {
    for (int64_t j = 0; j < pts[i]; j++) {
      // To avoid AMP autocasting caused by at::matmul
      auto acc_out = at::empty_like(acc);
      acc = at::matmul_out(acc_out, acc, acc);
    }
    output_pieces.push_back(acc.slice(0, 0, scs[i]));
    acc = acc.slice(0, scs[i]);
  }

  // Compose the result back
  auto output = at::cat(output_pieces, 0);
  return output.index_select(0, at::argsort(sorted_s_inds));
}
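
The bookkeeping above (compute `s`, sort, split into runs, square each run incrementally, restore order) can be sketched in pure Python without ATen. The helper names below are hypothetical, and `theta = 3.0` is chosen only so that the example `norm` [27, 6, 6, 0.05] from the source comments reproduces `s = [4, 1, 1, 0]`; the real `theta` is a dtype-dependent constant.

```python
import math
from itertools import groupby

def scale_square_plan(norm, theta):
    """Hypothetical helper mirroring the s / split_counts / mul_times math."""
    # s[i] = max(0, ceil(log2(norm[i] / theta))): squarings needed per matrix.
    s = [max(0, math.ceil(math.log2(n / theta))) for n in norm]
    # Mirror at::sort: sorted values plus the permutation that produced them.
    order = sorted(range(len(s)), key=lambda i: s[i])
    sorted_s = [s[i] for i in order]
    # Mirror unique_consecutive with counts: run lengths of equal values.
    uniques, counts = zip(*((v, len(list(g))) for v, g in groupby(sorted_s)))
    # Mirror at::diff with a prepended zero: extra squarings per split.
    mul_times = [uniques[0]] + [b - a for a, b in zip(uniques, uniques[1:])]
    return s, order, list(counts), mul_times

def batched_square(mats, norm, theta):
    """Apply the batched squaring plan to a list of square matrices."""
    _, order, counts, mul_times = scale_square_plan(norm, theta)
    matmul = lambda A, B: [[sum(x * y for x, y in zip(row, col))
                            for col in zip(*B)] for row in A]
    acc = [mats[i] for i in order]       # reorder to match sorted_s
    pieces = []
    for c, m in zip(counts, mul_times):
        for _ in range(m):               # square every remaining matrix
            acc = [matmul(A, A) for A in acc]
        pieces.extend(acc[:c])           # this split is now finished
        acc = acc[c:]
    out = [None] * len(mats)             # undo the sort permutation
    for k, orig_i in enumerate(order):
        out[orig_i] = pieces[k]
    return out
```

Here `mats` stands in for the already-exponentiated scaled matrices (`mexp_scaled` in the source). With 1x1 matrices [[2]], [[3]], [[5]], [[7]] and the example `norm`, each entry ends up squared `s[i]` times, and the results come back in the original batch order.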
