serial_vec_logsoftmax_range — pytorch Architecture

Architecture documentation for the serial_vec_logsoftmax_range function template in LogSoftmaxKernelImpl.h from the pytorch codebase. The overload listed here is gated by std::is_same_v (a standard type trait, not a pytorch class): it is enabled only when scalar_t differs from at::opmath_type<scalar_t>, i.e. for reduced-precision inputs such as BFloat16 and Half, which are converted to float for the computation.
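
The enable_if_t constraint selects between two overloads of the same function name at compile time. Below is a minimal sketch of that dispatch pattern using toy stand-in types; BFloat16Like, opmath, opmath_t, and serial_range_sketch are illustrative names, not pytorch or ATen identifiers.

#include <cstdio>
#include <type_traits>

// Toy trait standing in for at::opmath_type, which maps reduced-precision
// storage types (BFloat16, Half) to their float math type.
struct BFloat16Like {};  // hypothetical stand-in for at::BFloat16
template <typename T> struct opmath { using type = T; };
template <> struct opmath<BFloat16Like> { using type = float; };
template <typename T> using opmath_t = typename opmath<T>::type;

// Selected when scalar_t is already its own math type (e.g. float, double).
template <typename scalar_t>
std::enable_if_t<std::is_same_v<scalar_t, opmath_t<scalar_t>>, void>
serial_range_sketch() { std::puts("full-precision overload"); }

// Selected when scalar_t differs from its math type; this mirrors the
// constraint on the function documented below.
template <typename scalar_t>
std::enable_if_t<!std::is_same_v<scalar_t, opmath_t<scalar_t>>, void>
serial_range_sketch() { std::puts("reduced-precision overload (compute in float)"); }

int main() {
  serial_range_sketch<float>();         // full-precision overload
  serial_range_sketch<BFloat16Like>();  // reduced-precision overload
}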

Entity Profile

Source Code

aten/src/ATen/native/cpu/LogSoftmaxKernelImpl.h lines 214–336

template <typename scalar_t>
std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
serial_vec_logsoftmax_range(
    const scalar_t* input_data_base,
    scalar_t* output_data_base,
    int64_t inner_size,
    int64_t chunk_size,
    int64_t num_chunks,
    int64_t dim_size,
    int64_t begin,
    int64_t end) {
  using Vec = vec::Vectorized<scalar_t>;
  using fVec = vec::Vectorized<float>;
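  // For reduced-precision scalar_t, Vec holds twice as many lanes as fVec,
  // so each Vec loaded below is processed as two fVec halves.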
  auto buffer = std::make_unique<float []>(chunk_size * 2);
  float* input_max_data = buffer.get();
  float* tmp_sum_data = buffer.get() + chunk_size;

  // thread-local buffer that holds the input data in float32, so the next two passes can skip the dtype conversion
  auto input_buffer = std::make_unique<float []>(dim_size * chunk_size);
  float* input_buffer_data = input_buffer.get();

  // init: fill the per-column running max with -inf and the running sum with 0
  for (int64_t i = begin; i < end; i++) {
    int64_t outer_idx = i / num_chunks;
    int64_t k = i % num_chunks;
    int64_t inner_idx_begin = k * chunk_size;
    int64_t size = std::min(chunk_size, inner_size - inner_idx_begin);

    fVec zero_fvec = fVec(float(0));
    fVec min_fvec = fVec(-std::numeric_limits<float>::infinity());
    int64_t d0 = 0;
    for (; d0 < size - (size % Vec::size()); d0 += Vec::size()) {
      min_fvec.store(input_max_data + d0);
      min_fvec.store(input_max_data + d0 + fVec::size());
      zero_fvec.store(tmp_sum_data + d0);
      zero_fvec.store(tmp_sum_data + d0 + fVec::size());
    }
    for (; d0 < size; d0++) {
      input_max_data[d0] = -std::numeric_limits<float>::infinity();
      tmp_sum_data[d0] = float(0);
    }

    // compute max
    for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
      const scalar_t* input_ptr = input_data_base + outer_idx * dim_size * inner_size
          + dim_idx * inner_size + inner_idx_begin;
      float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;

      int64_t d1 = 0;
      for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
        Vec data_vec = Vec::loadu(input_ptr + d1);
        auto [data_fvec0, data_fvec1] = vec::convert_to_float<scalar_t>(data_vec);
        fVec max_fvec0 = fVec::loadu(input_max_data + d1);
        fVec max_fvec1 = fVec::loadu(input_max_data + d1 + fVec::size());
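        // running elementwise max: blendv keeps the data lanes where data > current max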
        max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
        max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
        max_fvec0.store(input_max_data + d1);
        max_fvec1.store(input_max_data + d1 + fVec::size());

        // cache the 'converted' float input
        data_fvec0.store(input_buffer_ptr + d1);
        data_fvec1.store(input_buffer_ptr + d1 + fVec::size());
      }
      for (; d1 < size; d1++) {
        float data_val = float(input_ptr[d1]);
        float max_val = input_max_data[d1];
        input_max_data[d1] = data_val > max_val ? data_val : max_val;
        input_buffer_ptr[d1] = data_val;
      }
    }

    // compute sum of (x - max).exp()
    for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
      float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;

      int64_t d2 = 0;
      for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
        fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d2);
        fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d2 + fVec::size());
        fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d2);
        fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d2 + fVec::size());
        fVec max_fvec0 = fVec::loadu(input_max_data + d2);
        fVec max_fvec1 = fVec::loadu(input_max_data + d2 + fVec::size());
        sum_fvec0 += (data_fvec0 - max_fvec0).exp();
        sum_fvec1 += (data_fvec1 - max_fvec1).exp();
        sum_fvec0.store(tmp_sum_data + d2);
        sum_fvec1.store(tmp_sum_data + d2 + fVec::size());
      }
      for (; d2 < size; d2++) {
        float data_val = input_buffer_ptr[d2];
        float max_val = input_max_data[d2];
        tmp_sum_data[d2] += std::exp(data_val - max_val);
      }
    }

    // apply log
    vec::map([](fVec x) { return x.log(); }, tmp_sum_data, tmp_sum_data, size);
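    // tmp_sum_data now holds log(sum_j exp(x_j - max)) for each column of the chunk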

    // compute x - max - log_sum and convert the result back to scalar_t
    for (int64_t dim_idx = 0; dim_idx < dim_size; dim_idx++) {
      float* input_buffer_ptr = input_buffer_data + dim_idx * chunk_size;
      scalar_t* output_ptr = output_data_base + outer_idx * dim_size * inner_size
          + dim_idx * inner_size + inner_idx_begin;

      int64_t d3 = 0;
      for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
        fVec data_fvec0 = fVec::loadu(input_buffer_ptr + d3);
        fVec data_fvec1 = fVec::loadu(input_buffer_ptr + d3 + fVec::size());
        fVec max_fvec0 = fVec::loadu(input_max_data + d3);
        fVec max_fvec1 = fVec::loadu(input_max_data + d3 + fVec::size());
        fVec sum_fvec0 = fVec::loadu(tmp_sum_data + d3);
        fVec sum_fvec1 = fVec::loadu(tmp_sum_data + d3 + fVec::size());
        fVec out_fvec0 = data_fvec0 - max_fvec0 - sum_fvec0;
        fVec out_fvec1 = data_fvec1 - max_fvec1 - sum_fvec1;
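        // pack the two float halves back into one reduced-precision Vec for the store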
        Vec out_vec = vec::convert_from_float<scalar_t>(out_fvec0, out_fvec1);
        out_vec.store(output_ptr + d3);
      }
      for (; d3 < size; d3++) {
        output_ptr[d3] = scalar_t(input_buffer_ptr[d3] - input_max_data[d3] - tmp_sum_data[d3]);
      }
    }
  }
}
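
The three passes above implement the numerically stable log-softmax, log_softmax(x)_i = x_i - max(x) - log(sum_j exp(x_j - max(x))), along the reduction dimension for each column of the chunk. The following scalar sketch of the same computation over a single 1-D slice is illustrative only (log_softmax_reference is not a pytorch function):

#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Illustrative scalar reference for one 1-D slice; the kernel above performs
// the same three passes, vectorized and over many inner-dimension columns at once.
std::vector<float> log_softmax_reference(const std::vector<float>& x) {
  float max_val = -std::numeric_limits<float>::infinity();
  for (float v : x) max_val = v > max_val ? v : max_val;  // pass 1: max
  float sum = 0.f;
  for (float v : x) sum += std::exp(v - max_val);         // pass 2: sum of exp(x - max)
  float log_sum = std::log(sum);                          // apply log
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i)
    out[i] = x[i] - max_val - log_sum;                    // pass 3: x - max - log_sum
  return out;
}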
