compute_T18_scale_square Function — PyTorch Architecture
Architecture documentation for the compute_T18_scale_square function template in LinearAlgebra.cpp from the PyTorch codebase.
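For context, this routine is the back half of the standard scaling-and-squaring scheme for the matrix exponential: each input matrix is scaled down by a power of two before the degree-18 Taylor approximation (compute_T18), and the result is recovered by repeated squaring. In the notation of the code below (with theta the accuracy threshold for the approximation):

```latex
\exp(A) = \left( \exp\!\left( A / 2^{s} \right) \right)^{2^{s}},
\qquad
s = \max\!\left( \left\lceil \log_2 \frac{\lVert A \rVert}{\theta} \right\rceil,\; 0 \right)
```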
Entity Profile
Source Code
aten/src/ATen/native/LinearAlgebra.cpp lines 2572–2640
template <typename scalar_t>
Tensor compute_T18_scale_square(
const Tensor& a,
const Tensor& norm,
scalar_t theta
) {
// Scale
// We eventually need repeated matrix multiplications (squarings) to undo the
// scaling and recover the final result.
// For example, if `norm` is [27, 6, 6, 0.05], we get `s` as [4, 1, 1, 0].
// Naively we could then compute matrix[0]^(2^4), matrix[1]^(2^1) and
// matrix[2]^(2^1) one by one, but such one-by-one computation would be slow.
const auto s = (at::ceil(at::log2(norm / theta))).clamp(/*min=*/0);
const auto pow2s = at::pow(2, -s);
const auto a_scaled = a * pow2s.view({-1, 1, 1});
auto mexp_scaled = at::native::compute_T18<scalar_t>(a_scaled);
// Sort:
// The inputs are batches of square matrices. Once the matrices that need only
// a few squarings are finished, the remaining work (e.g. the extra squarings
// for `matrix 0`) applies to a progressively smaller batch, which lets us
// perform the matrix multiplications in batches.
// The first step is to sort the tensor `s`, so that matrices needing the same
// number of squarings become contiguous and can be processed by range.
// With the above example, `sorted_s` is [0, 1, 1, 4]; we also keep the sort
// indices so we can restore the original order at the end.
auto [sorted_s, sorted_s_inds] = at::sort(s, /*dim=*/0);
sorted_s = sorted_s.to(at::kLong);
// Then we call `unique_consecutive` to split `sorted_s` into runs of equal
// values; with the above example, `split_counts` is [1, 2, 1].
auto split_counts = std::get<2>(at::unique_consecutive(sorted_s, true, /*return_counts=*/true));
// We also need the index of the last element of each split, so we know how
// many additional squarings each split requires. Note that we never compute
// the absolute powers: the squarings accumulate across splits, so each split
// only needs the difference from the previous one.
// With the above example, `mul_times` will be [0, 1, 3].
auto split_edges = at::cumsum(split_counts, /*dim=*/0) - 1;
auto unique_s = sorted_s.index_select(0, split_edges).clamp(/*min=*/0);
auto mul_times = at::diff(unique_s, 1, -1, /*prepend=*/unique_s.new_zeros({1}));
// Square
auto section_values = at::cat({split_counts, mul_times}, 0).to(at::kCPU);
TORCH_INTERNAL_ASSERT(section_values.is_contiguous());
const auto section_numel = section_values.numel() / 2;
auto scs = section_values.template data_ptr<int64_t>();
auto pts = &scs[section_numel];
// We now do the matrix multiplications in batches. With the above example:
// 1. Square all matrices 0 times (`mul_times[0]`), store the first
//    `split_counts[0]` = 1 matrix, and keep the rest (acc[1:]),
// 2. Square the remaining matrices once, store the next 2, keep acc[2:],
// 3. Square the remaining matrix 3 more times and store it.
// All processed matrices are collected in `output_pieces`.
std::vector<Tensor> output_pieces;
auto acc = mexp_scaled.index_select(0, sorted_s_inds);
for (int64_t i = 0; i < section_numel; ++i) {
for (int64_t j = 0; j < pts[i]; j++) {
// To avoid AMP autocasting caused by at::matmul
auto acc_out = at::empty_like(acc);
acc = at::matmul_out(acc_out, acc, acc);
}
output_pieces.push_back(acc.slice(0, 0, scs[i]));
acc = acc.slice(0, scs[i]);
}
// Compose the result back
auto output = at::cat(output_pieces, 0);
return output.index_select(0, at::argsort(sorted_s_inds));
}
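The scale/sort/square bookkeeping above can be sketched in NumPy. This is an illustrative re-implementation, not the PyTorch code: theta = 3.0 is chosen only so that the documented example norms [27, 6, 6, 0.05] produce s = [4, 1, 1, 0] (the real theta depends on the scalar type), and a batch of random 2x2 matrices stands in for the output of compute_T18.

```python
import numpy as np

# Illustrative threshold: chosen so norm [27, 6, 6, 0.05] -> s [4, 1, 1, 0].
theta = 3.0
norm = np.array([27.0, 6.0, 6.0, 0.05])

# Scale: s = clamp(ceil(log2(norm / theta)), min=0)
s = np.clip(np.ceil(np.log2(norm / theta)), 0, None).astype(np.int64)

# Stand-in for compute_T18(a * 2**-s): a batch of random 2x2 matrices.
rng = np.random.default_rng(0)
mexp_scaled = rng.standard_normal((4, 2, 2))

# Sort: group matrices that need the same number of squarings together.
sorted_s_inds = np.argsort(s, kind="stable")          # [3, 1, 2, 0]
sorted_s = s[sorted_s_inds]                           # [0, 1, 1, 4]
# np.unique on sorted input plays the role of at::unique_consecutive.
unique_s, split_counts = np.unique(sorted_s, return_counts=True)
mul_times = np.diff(unique_s, prepend=0)              # [0, 1, 3]

# Square: cumulative batched squaring, peeling off finished matrices.
pieces = []
acc = mexp_scaled[sorted_s_inds]
for count, times in zip(split_counts, mul_times):
    for _ in range(times):
        acc = acc @ acc                               # batched matmul
    pieces.append(acc[:count])                        # these are done
    acc = acc[count:]                                 # keep squaring the rest
output = np.concatenate(pieces, 0)[np.argsort(sorted_s_inds)]

# Each matrix i ends up squared s[i] times: output[i] == M_i^(2^s[i]).
for i in range(4):
    expected = np.linalg.matrix_power(mexp_scaled[i], 2 ** s[i])
    assert np.allclose(output[i], expected)
```

Note how the cumulative squaring means matrix 0 is squared 0 + 1 + 3 = 4 times in total, matching s[0] = 4, even though each loop iteration only applies the incremental `mul_times` difference.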