Home / Class/ scalar_t Class — pytorch Architecture

scalar_t Class — pytorch Architecture

Architecture documentation for `scalar_t`, the template type parameter of the `serial_vec_logsoftmax_range` function in LogSoftmaxKernelImpl.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cpu/LogSoftmaxKernelImpl.h lines 115–212

template <typename scalar_t>
std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
serial_vec_logsoftmax_range(
    const scalar_t* input_data_base,
    scalar_t* output_data_base,
    int64_t inner_size,
    int64_t chunk_size,
    int64_t num_chunks,
    int64_t dim_size,
    int64_t begin,
    int64_t end) {
  using Vec = vec::Vectorized<scalar_t>;
  // thread local temp buffer which holds vertical reduction result: max and sum.
  auto buffer = std::make_unique<scalar_t []>(chunk_size * 2);
  scalar_t* input_max_data = buffer.get();
  scalar_t* tmp_sum_data = buffer.get() + chunk_size;

  for (int64_t i = begin; i < end; i++) {
    int64_t outer_idx = i / num_chunks;
    int64_t k = i % num_chunks;
    int64_t inner_idx_begin = k * chunk_size;
    int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);

    // init
    Vec zero_vec = Vec(scalar_t(0));
    Vec min_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
    int64_t d0 = 0;
    for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
      min_vec.store(input_max_data + d0);
      zero_vec.store(tmp_sum_data + d0);
    }
    for (; d0 < size; d0++) {
      input_max_data[d0] = -std::numeric_limits<scalar_t>::infinity();
      tmp_sum_data[d0] = scalar_t(0);
    }

    // compute max
    for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
      const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
          + dim_idx * inner_size + inner_idx_begin;

      int64_t d1 = 0;
      for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
        Vec data_vec = Vec::loadu(input_ptr + d1);
        Vec max_vec = Vec::loadu(input_max_data + d1);
        max_vec = Vec::blendv(max_vec, data_vec, data_vec > max_vec);
        max_vec.store(input_max_data + d1);
      }
      for (; d1 < size; d1++) {
        scalar_t data_val = input_ptr[d1];
        scalar_t max_val = input_max_data[d1];
        input_max_data[d1] = data_val > max_val ? data_val : max_val;
      }
    }

    // compute sum of (x - max).exp()
    for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
      const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
          + dim_idx * inner_size + inner_idx_begin;

      int64_t d2 = 0;
      for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
        Vec data_vec = Vec::loadu(input_ptr + d2);
        Vec sum_vec = Vec::loadu(tmp_sum_data + d2);
        Vec max_vec = Vec::loadu(input_max_data + d2);
        sum_vec += (data_vec - max_vec).exp();
        sum_vec.store(tmp_sum_data + d2);
      }
      for (; d2 < size; d2++) {
        scalar_t data_val = input_ptr[d2];
        scalar_t max_val = input_max_data[d2];
        tmp_sum_data[d2] += std::exp(data_val - max_val);
      }
    }

    // apply log
    vec::map([](Vec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);

    // compute x - max - sum
    for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
      int64_t offset = outer_idx * dim_size * inner_size + dim_idx * inner_size + inner_idx_begin;
      const scalar_t* input_ptr = input_data_base + offset;
      scalar_t* output_ptr = output_data_base + offset;

      int64_t d3 = 0;
      for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
        Vec data_vec = Vec::loadu(input_ptr + d3);
        Vec max_vec = Vec::loadu(input_max_data + d3);
        Vec sum_vec = Vec::loadu(tmp_sum_data + d3);
        Vec out_vec = data_vec - max_vec - sum_vec;
        out_vec.store(output_ptr + d3);
      }
      for (; d3 < size; d3++) {
        output_ptr[d3] = input_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3];
      }
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free