Home / Class/ computeStride_impl Class — pytorch Architecture

computeStride_impl Class — pytorch Architecture

Architecture documentation for the `computeStride_impl` function template in aten/src/ATen/TensorUtils.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/TensorUtils.cpp lines 326–407

// Computes the strides that let data laid out with sizes `oldshape` and
// strides `oldstride` be reinterpreted with sizes `newshape` without copying.
// Returns std::nullopt when no such strides exist.
//
// Template parameters:
//   ResultVec   - returned stride container. Must support ResultVec(n),
//                 ResultVec(n, 1) and operator[].
//   NewShapeVec - type of the three input containers. Must provide empty(),
//                 size(), back(), operator[] and equals().
//   Numel       - scalar type used for element counts and strides.
//                 Presumably int64_t or c10::SymInt; TODO confirm against the
//                 instantiations.
//
// Parameters:
//   oldshape  - sizes of the existing layout.
//   oldstride - strides of the existing layout, indexed in parallel with
//               oldshape.
//   newshape  - requested sizes. NOTE(review): the zero-element path below
//               does not check that newshape's element count matches
//               oldshape's. Presumably the caller has already validated this;
//               confirm.
//   toResult  - converts a NewShapeVec into a ResultVec. It is only used to
//               return oldstride unchanged in the zero-element, same-shape
//               case.
//
// Returns:
//   - strides of all ones, of length newshape.size(), if oldshape is empty;
//   - oldstride (via toResult) if the element count is known to be 0 and
//     newshape equals oldshape;
//   - contiguous strides for newshape, treating size-0 dims as size 1, if the
//     element count is known to be 0 and the shape changes;
//   - otherwise the computed view strides. The result is std::nullopt when
//     newshape cannot be expressed over the existing layout, or when that
//     cannot be decided for symbolic sizes (see the comments in the main
//     loop).
//
// NOTE(review): the TORCH_GUARD_OR_* helpers are defined elsewhere. They may
// have side effects during symbolic tracing (assumption; not visible here).
// Keep their order and short-circuiting unchanged when editing this function.
template <typename ResultVec, typename NewShapeVec, typename Numel>
inline static std::optional<ResultVec> computeStride_impl(
    const NewShapeVec& oldshape,
    const NewShapeVec& oldstride,
    const NewShapeVec& newshape,
    ResultVec toResult(const NewShapeVec&)
) {
  // No old dimensions: every new dimension gets stride 1.
  if (oldshape.empty()) {
    return ResultVec(newshape.size(), 1);
  }

  // NOTE: stride is arbitrary in the numel() == 0 case;
  // to match NumPy behavior we copy the strides if the size matches, otherwise
  // we use the stride as if it were computed via resize.
  // This could perhaps be combined with the below code, but the complexity
  // didn't seem worth it.
  const Numel numel = c10::multiply_integers(oldshape);
  // GUARD_OR_FALSE: if emptiness cannot be decided, take the general path below.
  bool zero_numel = TORCH_GUARD_OR_FALSE(sym_eq(numel, 0));
  if (zero_numel && oldshape.equals(newshape)) {
    return toResult(oldstride);
  }

  ResultVec newstride(newshape.size());
  if (zero_numel) {
    // Contiguous, "resize"-style strides built right-to-left. Size-0 dims are
    // treated as size 1 so they do not zero out the strides to their left.
    for (int64_t view_d = newshape.size() - 1; view_d >= 0; view_d--) {
      if (view_d == (int64_t)(newshape.size() - 1)) {
        newstride[view_d] = 1;
      } else {
        newstride[view_d] =
          std::max<Numel>(newshape[view_d+1], Numel(1)) * newstride[view_d+1];
      }
    }
    return newstride;
  }

  // General case. Walk the old dims from innermost (last) to outermost and
  // split them into "chunks": runs of adjacent dims that are contiguous with
  // one another. Each chunk must be covered exactly by a run of new dims.
  // view_d is the next new dim (from the right) still waiting for a stride.
  int64_t view_d = (int64_t)newshape.size() - 1;
  // stride for each subspace in the chunk
  Numel chunk_base_stride = oldstride.back();
  // numel in current chunk
  Numel tensor_numel = 1;
  // numel covered so far by the new dims assigned to the current chunk
  Numel view_numel = 1;

  // The usages of TORCH_GUARD_OR_TRUE/TORCH_GUARD_OR_FALSE below could result in returning
  // std::nullopt, which has the effect of falling back to a clone when unbacked symints are present.
  // But it will not result in returning different or wrong results.
  for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
    tensor_numel *= oldshape[tensor_d];
    // if end of tensor size chunk, check view
    // Old dim tensor_d - 1 extends the current chunk if it has size 1, or if
    // its stride equals (numel of the chunk so far) * chunk_base_stride. When
    // either fact cannot be decided, GUARD_OR_TRUE treats it as "not equal",
    // which ends the chunk here. That choice is conservative.
    if ((tensor_d == 0) ||
        (TORCH_GUARD_OR_TRUE(sym_ne(oldshape[tensor_d - 1], 1)) &&
        TORCH_GUARD_OR_TRUE(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) {
      // We want to accumulate new dims into view_numel until view_numel == tensor_numel. If we do not
      // know whether that is satisfied, we keep accumulating. For example, if view_numel = 1 and
      // tensor_numel = u1, we want to take that path and view_numel will become u0. In the next
      // iteration, if u0 == u1 we want to stop. That's why we use TORCH_GUARD_OR_TRUE below.
      //
      // We use TORCH_GUARD_OR_FALSE, not TORCH_GUARD_OR_TRUE, when comparing newshape[view_d] == 1:
      // once we know view_numel < tensor_numel is false we want to stop, unless we know for sure
      // that newshape[view_d] == 1. In that case one more iteration leaves view_numel unchanged,
      // so we would stop in the next iteration anyway. For example, if view_numel = u0,
      // tensor_numel = u1 and u0 == u1, we want to stop unless newshape[view_d] == 1; taking one
      // more iteration keeps view_numel = u0 and tensor_numel = u1.
      while (view_d >= 0 &&
            (TORCH_GUARD_OR_TRUE(sym_lt(view_numel, tensor_numel)) || TORCH_GUARD_OR_FALSE(sym_eq(newshape[view_d], 1)))) {
        // Inside a chunk the new dims are laid out contiguously on top of the
        // chunk's base stride.
        newstride[view_d] = view_numel * chunk_base_stride;
        view_numel *= newshape[view_d];
        view_d--;
      }
      // The new dims assigned to this chunk must cover exactly the chunk's
      // elements. Otherwise some new dim straddles a non-contiguous boundary
      // (or equality could not be proven), so no view is returned.
      if (TORCH_GUARD_OR_TRUE(sym_ne(view_numel, tensor_numel))) {
        return std::nullopt;
      }
      // Start a new chunk based at the stride of the next old dim.
      if (tensor_d > 0) {
        chunk_base_stride = oldstride[tensor_d - 1];
        tensor_numel = 1;
        view_numel = 1;
      }
    }
  }
  // Every new dim must have been assigned to some chunk.
  if (view_d != -1) {
    return std::nullopt;
  }
  return newstride;
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free