Home / Class / is_offsets_like Class — pytorch Architecture

is_offsets_like Class — pytorch Architecture

Architecture documentation for is_offsets_like, the boolean template parameter of the _segment_reduce_lengths_cpu_kernel1 function template in SegmentReduce.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/SegmentReduce.cpp lines 30–128

// CPU kernel that reduces `data` along `axis` one segment at a time and writes
// one reduced value per (outer, segment, inner) position into `output`.
//
// Template parameters:
//   T               - element type of `lengths_data`.
//   is_offsets_like - when false, `lengths_data` holds `segment_count` segment
//                     lengths per outer row and segment boundaries are the
//                     running sum of those lengths starting at 0. When true,
//                     it holds `segment_count + 1` boundaries per outer row
//                     and segment k spans [entry k, entry k + 1).
//
// Parameters:
//   reduction           - MAX, MIN, SUM, MEAN or PROD.
//   data                - values to reduce; dispatched over the floating types
//                         plus kBFloat16 and kHalf.
//   lengths_data        - segment lengths or offsets (see is_offsets_like).
//   axis                - dimension of `data`/`output` that is segmented.
//   initial             - optional seed for every segment. When absent the
//                         seed is -inf (MAX), +inf (MIN), 0 (SUM/MEAN) or
//                         1 (PROD).
//   output              - destination tensor, written in place.
//   segment_count       - number of segments per outer row.
//   lengths_stride_axis - stride used to locate each outer row's entries in
//                         `lengths_data`.
//
// Behavior visible in the body:
//   * MAX/MIN return NaN as soon as a NaN value is encountered.
//   * MEAN over an empty segment yields NaN when `initial` is absent;
//     otherwise an empty segment yields the seed value unchanged.
//   * A negative segment length fails a TORCH_CHECK.
//
// NOTE(review): the flat indexing below adds `inner_idx` with unit stride and
// derives the outer-row base from `stride(axis) * size(axis)`, and the lengths
// index adds `dim_idx` unscaled. Presumably callers pass contiguous tensors
// and lengths; confirm against the call sites.
template <typename T, bool is_offsets_like=false>
void _segment_reduce_lengths_cpu_kernel1(
    ReductionType reduction,
    const Tensor& data,
    const T* lengths_data,
    int64_t axis,
    const std::optional<Scalar>& initial,
    Tensor& output,
    int64_t segment_count,
    int64_t lengths_stride_axis) {
  // outer_offset is the size of the outer dimensions of output (before axis)
  // inner_offset is the size of the inner dimensions of output (after axis)
  int64_t outer_offset = 1;
  int64_t inner_offset = 1;
  for (int64_t d = 0; d < axis; d++) {
    outer_offset *= output.size(d);
  }
  for (int64_t d = axis + 1; d < output.dim(); d++) {
    inner_offset *= output.size(d);
  }
  // Offsets carry one more entry than lengths (a trailing end boundary).
  const int64_t lengths_size_axis =
      is_offsets_like ? segment_count + 1 : segment_count;
  const auto data_stride_axis = data.stride(axis);
  const auto data_size_axis = data.size(axis);
  const auto output_stride_axis = output.stride(axis);
  const auto output_size_axis = output.size(axis);
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", [&]() {
        auto* output_data = output.data_ptr<scalar_t>();
        const auto* values_data = data.const_data_ptr<scalar_t>();
        for (const auto outer_idx : c10::irange(outer_offset)) {
          // Per-row base offsets do not depend on the segment or the inner
          // index, so compute them once per outer row.
          const int64_t lengths_base =
              outer_idx * lengths_stride_axis * lengths_size_axis;
          const int64_t data_base =
              outer_idx * data_stride_axis * data_size_axis;
          const int64_t output_base =
              outer_idx * output_stride_axis * output_size_axis;

          // With offsets the first boundary is read from the input; with
          // lengths the first segment always starts at 0.
          int64_t segment_end = 0;
          if constexpr (is_offsets_like) {
            segment_end = lengths_data[lengths_base];
          }

          for (const auto dim_idx : c10::irange(segment_count)) {
            const int64_t segment_start = segment_end;
            const int64_t lengths_idx = lengths_base + dim_idx;
            int64_t segment_length = 0;
            if constexpr (is_offsets_like) {
              segment_end = lengths_data[lengths_idx + 1];
              segment_length = segment_end - segment_start;
            } else {
              segment_length = lengths_data[lengths_idx];
              segment_end += segment_length;
            }

            for (const auto inner_idx : c10::irange(inner_offset)) {
              // ===== step1: initialize starting value
              // Value-initialized so the accumulator is never indeterminate
              // should `reduction` match none of the branches below.
              scalar_t initial_value{};
              if (initial.has_value()) {
                initial_value = initial.value().to<scalar_t>();
              } else if (reduction == ReductionType::MAX) {
                initial_value = -std::numeric_limits<scalar_t>::infinity();
              } else if (
                  reduction == ReductionType::MEAN ||
                  reduction == ReductionType::SUM) {
                initial_value = 0;
              } else if (reduction == ReductionType::MIN) {
                initial_value = std::numeric_limits<scalar_t>::infinity();
              } else if (reduction == ReductionType::PROD) {
                initial_value = 1;
              }

              // Reject a negative length before walking the data rather than
              // after the reduction has already run.
              TORCH_CHECK(segment_length >= 0);

              // ===== step2: apply reduction
              for (const auto j : c10::irange(segment_start, segment_end)) {
                const int64_t data_index =
                    data_base + j * data_stride_axis + inner_idx;
                const auto val = values_data[data_index];
                if (reduction == ReductionType::MAX) {
                  // A NaN input replaces the running value.
                  initial_value = at::_isnan(val)
                      ? val
                      : std::max<scalar_t>(initial_value, val);
                } else if (
                    reduction == ReductionType::MEAN ||
                    reduction == ReductionType::SUM) {
                  initial_value = initial_value + val;
                } else if (reduction == ReductionType::MIN) {
                  initial_value = at::_isnan(val)
                      ? val
                      : std::min<scalar_t>(initial_value, val);
                } else if (reduction == ReductionType::PROD) {
                  initial_value = initial_value * val;
                }
              }

              // ===== step3: finalize reduction
              if (segment_length == 0 && !initial.has_value() &&
                  reduction == ReductionType::MEAN) {
                // Mean of an empty segment with no seed is undefined.
                initial_value = static_cast<scalar_t>(NAN);
              } else if (
                  reduction == ReductionType::MEAN &&
                  segment_length > 0 && !at::_isnan(initial_value)) {
                // The accumulated sum (including any seed) is divided by the
                // number of elements in the segment.
                initial_value = initial_value / segment_length;
              }
              const int64_t output_index =
                  output_base + dim_idx * output_stride_axis + inner_idx;
              output_data[output_index] = initial_value;
            }
          }
        }
      });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free