Home / Function / host_softmax — PyTorch Architecture

host_softmax Function — PyTorch Architecture

Architecture documentation for the host_softmax function template in SoftMax.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/SoftMax.cpp lines 152–243

// Computes a masked softmax of `input` along dimension `dim`, writing the
// result into `output` (assumed to have the same shape/strides as `input`;
// both must be contiguous for the flat pointer arithmetic below to be valid).
//
// `mask` points to boolean mask data; an element whose mask entry is `true`
// is excluded from the softmax: it contributes neither to the row max nor to
// the normalizing sum, and its output is forced to zero. The mask layout is
// selected by `mask_type_`:
//   0 (src_mask):             attention mask of shape LxL, shared across
//                             batch and heads. NOTE(review): the remapping
//                             below implies input is BxHxLxL with softmax
//                             over the last dim — confirm at callers.
//   1 (src_key_padding_mask): padding mask of shape BxL, broadcast over H
//                             and the query dimension.
//   2 (default_mask):         mask has exactly the same shape as `input`.
//
// If every element of a row is masked out, that row of `output` is filled
// with NaN (each zero gets multiplied by NaN), signalling "no valid entries".
template <typename scalar_t>
void host_softmax(
    Tensor& output,
    const Tensor& input,
    const int64_t dim,
    bool* mask,
    const std::optional<int64_t> mask_type_) {

  TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined");
  int64_t mask_type = mask_type_.value();
  // If mask_type == 2, then mask_.sizes() must equal input_.sizes()
  TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask), or 2 (default_mask)");

  // Collapse the tensor into (outer, dim, inner): the softmax runs along
  // `dim`; every dimension before it folds into `outer_size`, every one
  // after it into `inner_size`.
  int64_t outer_size = 1;
  int64_t dim_size = input.size(dim);
  int64_t inner_size = 1;
  for (const auto i : c10::irange(dim)) {
    outer_size *= input.size(i);
  }
  for (int64_t i = dim + 1; i < input.dim(); ++i) {
    inner_size *= input.size(i);
  }
  int64_t dim_stride = inner_size;
  int64_t outer_stride = dim_size * dim_stride;
  scalar_t* input_data_base = input.data_ptr<scalar_t>();
  scalar_t* output_data_base = output.data_ptr<scalar_t>();
  bool* mask_data_base = mask;
  // Hoist the mask index remapping factors out of the parallel loop: they
  // are loop-invariant, and querying Tensor::size() per row is wasted work.
  // Only queried for the mask types that use them (matching the original
  // access pattern for mask_type == 2, which never touches size(1)/size(2)).
  int64_t mask_mod = 1; // mask_type 0: mask_outer_idx = outer_idx % L
  int64_t mask_div = 1; // mask_type 1: mask_outer_idx = outer_idx / (H * L)
  if (mask_type == 0) {
    mask_mod = input.size(2);
  } else if (mask_type == 1) {
    mask_div = input.size(1) * input.size(2);
  }
  // Each iteration of parallel_for handles one softmax row of `dim_size`
  // elements, so scale the grain accordingly and clamp to at least 1.
  // std::max, not std::min: min would cap the grain at <= 1 (and allow 0
  // whenever dim_size > GRAIN_SIZE), forcing maximal task splitting.
  int64_t grain_size = std::max(internal::GRAIN_SIZE / dim_size, static_cast<int64_t>(1));
  parallel_for(
      0, outer_size * inner_size, grain_size,
      [&](int64_t begin, int64_t end) {
        for (const auto i : c10::irange(begin, end)) {
          int64_t outer_idx = i / inner_size;
          int64_t inner_idx = i % inner_size;
          scalar_t* input_data =
              input_data_base + outer_idx * outer_stride + inner_idx;
          scalar_t* output_data =
              output_data_base + outer_idx * outer_stride + inner_idx;
          // Process mask differently depending on the type:
          // For a generic mask of mask_type == 2, mask shape is the same as
          // the input shape, so indexing is the same.
          auto mask_outer_idx = outer_idx;
          if (mask_type == 0) {
              // Optimized case: attention mask of shape LxL
              // outer_idx goes over BxHxL, mask_outer_idx goes over L.
              mask_outer_idx = outer_idx % mask_mod;
          } else if (mask_type == 1) {
              // Optimized case: padding mask of shape BxL
              // outer_idx goes over BxHxL, mask_outer_idx goes over B.
              mask_outer_idx = outer_idx / mask_div;
          }

          bool* mask_data = mask_data_base + mask_outer_idx * outer_stride + inner_idx;

          // Max over the unmasked elements of the row (for numerical
          // stability of exp below). `is_meaningful_max` stays false when
          // every element is masked.
          bool is_meaningful_max = false;
          scalar_t max_input = input_data[0];
          for (const auto d : c10::irange(0, dim_size)) {
            if (!mask_data[d * dim_stride]) {
              max_input = is_meaningful_max
                  ? std::max(max_input, input_data[d * dim_stride])
                  : input_data[d * dim_stride];
              is_meaningful_max = true;
            }
          }

          // Exponentiate unmasked elements (masked ones contribute 0) and
          // accumulate the normalizing sum in higher precision.
          acc_type<scalar_t, false> tmpsum = 0;
          for (const auto d : c10::irange(dim_size)) {
            scalar_t z{};
            if (!mask_data[d * dim_stride]) {
              z = std::exp(input_data[d * dim_stride] - max_input);
            } else {
              z = 0;
            }
            output_data[d * dim_stride] = z;
            tmpsum += z;
          }

          // Fully-masked row: poison the output with NaN rather than
          // dividing by zero silently.
          if (tmpsum == 0) {
            tmpsum = std::numeric_limits<scalar_t>::quiet_NaN();
          } else {
            tmpsum = 1 / tmpsum;
          }

          // Normalize (or propagate NaN) into the output row.
          for (const auto d : c10::irange(dim_size)) {
            output_data[d * dim_stride] *= tmpsum;
          }
        }
      });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free