Home / Template Parameter / weight_index_stride — pytorch Architecture

weight_index_stride Template Parameter — pytorch Architecture

Architecture documentation for `weight_index_stride`, a non-type `int` template parameter (defaulting to `sizeof(scalar_t)`) of the `_compute_index_ranges_weights` function in UpSampleKernel.cpp from the pytorch codebase. It is not a class: it scales the per-output-element weight offsets written into the returned index tensor.

Entity Profile

Source Code

aten/src/ATen/native/cpu/UpSampleKernel.cpp lines 857–942

  template <typename scalar_t, typename aa_filter_fn_t, int weight_index_stride=sizeof(scalar_t)>
  static inline std::tuple<std::vector<Tensor>, int, scalar_t> _compute_index_ranges_weights(
    int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
    int64_t reshape_dim, scalar_t scale,
    int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, bool align_corners
  ) {

    std::vector<Tensor> output;

    scalar_t support;
    int max_interp_size = 0;
    if (antialias) {
        support = (scale >= 1.0) ? (interp_size * 0.5) * scale : interp_size * 0.5;
        max_interp_size = (int) std::ceil(support) * 2 + 1;
    } else {
        support = interp_size * 0.5;
        max_interp_size = interp_size;
    }

    auto new_shape = std::vector<int64_t>(ndims, 1);
    new_shape[reshape_dim] = output_size;

    // Bounds approach as in PIL: xmin/xmax
    output.emplace_back(
        empty(new_shape, at::device(kCPU).dtype(c10::CppTypeToScalarType<int64_t>())));
    output.emplace_back(
        empty(new_shape, at::device(kCPU).dtype(c10::CppTypeToScalarType<int64_t>())));
    output.emplace_back(
        empty(new_shape, at::device(kCPU).dtype(c10::CppTypeToScalarType<int64_t>())));

    {
      // Weights
      new_shape[reshape_dim] = output_size * max_interp_size;
      auto wts = empty(new_shape, at::device(kCPU).dtype(c10::CppTypeToScalarType<scalar_t>()));
      auto strides = wts.strides().vec();
      strides[reshape_dim] = 0;
      new_shape[reshape_dim] = output_size;
      wts = wts.as_strided(new_shape, strides);
      output.emplace_back(wts);
      // Weights indices
      output.emplace_back(
          empty(new_shape, at::device(kCPU).dtype(c10::CppTypeToScalarType<int64_t>())));
    }

    int64_t* idx_ptr_xmin = output[0].data_ptr<int64_t>();
    int64_t* idx_ptr_size = output[1].data_ptr<int64_t>();
    int64_t* idx_ptr_stride = output[2].data_ptr<int64_t>();
    scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
    int64_t* wt_idx_ptr = output[4].data_ptr<int64_t>();

    scalar_t wt_max = 0.0;
    for (const auto i : c10::irange(output_size)) {
      int64_t xmin = 0, xsize = 0;
      scalar_t wt_max_i;
      if (antialias) {
        wt_max_i = HelperInterpBase::_compute_indices_min_size_weights_aa(
            i,
            input_size,
            scale,
            support,
            wt_ptr + i * max_interp_size,
            max_interp_size,
            aa_filter_fn,
            xmin,
            xsize);
      } else {
        wt_max_i = HelperInterpBase::_compute_indices_min_size_weights(
            i,
            input_size,
            scale,
            wt_ptr + i * max_interp_size,
            max_interp_size,
            aa_filter_fn,
            align_corners,
            xmin,
            xsize);
      }
      wt_max = std::max(wt_max, wt_max_i);

      idx_ptr_xmin[i] = xmin * stride;
      idx_ptr_size[i] = xsize;
      idx_ptr_stride[i] = stride;
      wt_idx_ptr[i] = i * max_interp_size * weight_index_stride;
    }
    return {output, max_interp_size, wt_max};
  }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free