computeStride_impl Function — pytorch Architecture
Architecture documentation for the computeStride_impl template function in TensorUtils.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/TensorUtils.cpp lines 326–407
// Computes the strides that a view with shape `newshape` must have in order to
// alias the same memory as a tensor with shape `oldshape` and strides
// `oldstride`. Returns std::nullopt when no such strides exist — i.e. the
// requested view is impossible without copying, and the caller must fall back
// to materializing a contiguous clone.
//
// Template parameters allow this to run on both concrete (int64_t) and
// symbolic (SymInt) shapes:
//   ResultVec   — stride container type returned to the caller
//   NewShapeVec — container type of the input shapes/strides
//   Numel       — arithmetic type used for element-count math
//   toResult    — converts a NewShapeVec into a ResultVec (used on the
//                 copy-old-strides fast path for empty tensors)
template <typename ResultVec, typename NewShapeVec, typename Numel>
inline static std::optional<ResultVec> computeStride_impl(
    const NewShapeVec& oldshape,
    const NewShapeVec& oldstride,
    const NewShapeVec& newshape,
    ResultVec toResult(const NewShapeVec&)
) {
  // Scalar (0-d) source tensor: any target shape can view it; all-ones
  // strides suffice.
  if (oldshape.empty()) {
    return ResultVec(newshape.size(), 1);
  }

  // NOTE: stride is arbitrary in the numel() == 0 case;
  // to match NumPy behavior we copy the strides if the size matches, otherwise
  // we use the stride as if it were computed via resize.
  // This could perhaps be combined with the below code, but the complexity
  // didn't seem worth it.
  const Numel numel = c10::multiply_integers(oldshape);
  // With unbacked symints "don't know if zero" is treated as non-zero
  // (GUARD_OR_FALSE), so we fall through to the general algorithm below.
  bool zero_numel = TORCH_GUARD_OR_FALSE(sym_eq(numel, 0));
  if (zero_numel && oldshape.equals(newshape)) {
    // Empty tensor viewed at its own shape: keep the existing strides verbatim.
    return toResult(oldstride);
  }

  ResultVec newstride(newshape.size());
  if (zero_numel) {
    // Empty tensor viewed at a different shape: synthesize contiguous-style
    // strides as if computed via resize. Each size is clamped to at least 1 so
    // a zero-sized dimension does not zero out the strides of outer dims.
    for (int64_t view_d = newshape.size() - 1; view_d >= 0; view_d--) {
      if (view_d == (int64_t)(newshape.size() - 1)) {
        newstride[view_d] = 1;
      } else {
        newstride[view_d] =
          std::max<Numel>(newshape[view_d+1], Numel(1)) * newstride[view_d+1];
      }
    }
    return newstride;
  }

  // General case: walk both shapes right-to-left, partitioning the old
  // dimensions into maximal "chunks" that are mutually contiguous in memory.
  // Each chunk's element count must be tiled exactly by a run of new
  // dimensions, whose strides are derived from the chunk's base stride.
  int64_t view_d = (int64_t)newshape.size() - 1;
  // stride for each subspace in the chunk
  Numel chunk_base_stride = oldstride.back();
  // numel in current chunk
  Numel tensor_numel = 1;
  Numel view_numel = 1;
  // The usages of TORCH_GUARD_OR_TRUE/TORCH_GUARD_OR_FALSE below could result in returning
  // std::nullopt which has an effect of falling back to a clone when unbacked symints are present.
  // But it will not result in returning different or wrong results.
  for (int64_t tensor_d = oldshape.size() - 1; tensor_d >= 0; tensor_d--) {
    tensor_numel *= oldshape[tensor_d];
    // if end of tensor size chunk, check view
    // (a chunk ends at the outermost dim, or when the next-outer old dim is
    // not size-1 and its stride breaks contiguity with the chunk so far)
    if ((tensor_d == 0) ||
        (TORCH_GUARD_OR_TRUE(sym_ne(oldshape[tensor_d - 1], 1)) &&
         TORCH_GUARD_OR_TRUE(sym_ne(oldstride[tensor_d - 1], tensor_numel * chunk_base_stride)))) {
      // We want to accumulate stuff in view_numel until view_numel == tensor_numel, if we do not
      // know if that is satisfied we keep accumulating. For example if view_numel = 1 and tensor_numel = u1,
      // we want to take that path, view_numel will become u0. Next iteration if u0==u1 we want to stop.
      // That's why we use TORCH_GUARD_OR_TRUE below.
      // we use TORCH_GUARD_OR_FALSE and not TORCH_GUARD_OR_TRUE when comparing newshape[view_d] ==1 because
      // if we know view_numel < tensor_numel is false, we want to stop. Unless we know for sure newshape[view_d]==1
      // in that case we would stop in the next iteration anyway. For example, if view_numel = u0 and tensor_numel = u1,
      // and u0==u1, then want to stop unless newshape[view_d]==1. taking one more iteration will keep [view_numel = u0
      // and tensor_numel = u1].
      while (view_d >= 0 &&
            (TORCH_GUARD_OR_TRUE(sym_lt(view_numel, tensor_numel)) || TORCH_GUARD_OR_FALSE(sym_eq(newshape[view_d], 1)))) {
        // Strides within a chunk grow right-to-left by the sizes consumed so far.
        newstride[view_d] = view_numel * chunk_base_stride;
        view_numel *= newshape[view_d];
        view_d--;
      }
      if (TORCH_GUARD_OR_TRUE(sym_ne(view_numel, tensor_numel))) {
        // The new dims could not tile this chunk exactly (or we cannot prove
        // they do): no valid view strides exist.
        return std::nullopt;
      }
      if (tensor_d > 0) {
        // Begin the next chunk anchored at the next-outer old dimension.
        chunk_base_stride = oldstride[tensor_d - 1];
        tensor_numel = 1;
        view_numel = 1;
      }
    }
  }
  // Any new dimensions left unassigned mean the shapes' numels disagree.
  if (view_d != -1) {
    return std::nullopt;
  }
  return newstride;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free