weight_index_stride Template Parameter — pytorch Architecture
Architecture documentation for the `weight_index_stride` template parameter of `_compute_index_ranges_weights` in UpSampleKernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/UpSampleKernel.cpp lines 857–942
template <typename scalar_t, typename aa_filter_fn_t, int weight_index_stride=sizeof(scalar_t)>
static inline std::tuple<std::vector<Tensor>, int, scalar_t> _compute_index_ranges_weights(
int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
int64_t reshape_dim, scalar_t scale,
int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, bool align_corners
) {
std::vector<Tensor> output;
scalar_t support;
int max_interp_size = 0;
if (antialias) {
support = (scale >= 1.0) ? (interp_size * 0.5) * scale : interp_size * 0.5;
max_interp_size = (int) std::ceil(support) * 2 + 1;
} else {
support = interp_size * 0.5;
max_interp_size = interp_size;
}
auto new_shape = std::vector<int64_t>(ndims, 1);
new_shape[reshape_dim] = output_size;
// Bounds approach as in PIL: xmin/xmax
output.emplace_back(
empty(new_shape, at::device(kCPU).dtype(c10::CppTypeToScalarType<int64_t>())));
output.emplace_back(
empty(new_shape, at::device(kCPU).dtype(c10::CppTypeToScalarType<int64_t>())));
output.emplace_back(
empty(new_shape, at::device(kCPU).dtype(c10::CppTypeToScalarType<int64_t>())));
{
// Weights
new_shape[reshape_dim] = output_size * max_interp_size;
auto wts = empty(new_shape, at::device(kCPU).dtype(c10::CppTypeToScalarType<scalar_t>()));
auto strides = wts.strides().vec();
strides[reshape_dim] = 0;
new_shape[reshape_dim] = output_size;
wts = wts.as_strided(new_shape, strides);
output.emplace_back(wts);
// Weights indices
output.emplace_back(
empty(new_shape, at::device(kCPU).dtype(c10::CppTypeToScalarType<int64_t>())));
}
int64_t* idx_ptr_xmin = output[0].data_ptr<int64_t>();
int64_t* idx_ptr_size = output[1].data_ptr<int64_t>();
int64_t* idx_ptr_stride = output[2].data_ptr<int64_t>();
scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
int64_t* wt_idx_ptr = output[4].data_ptr<int64_t>();
scalar_t wt_max = 0.0;
for (const auto i : c10::irange(output_size)) {
int64_t xmin = 0, xsize = 0;
scalar_t wt_max_i;
if (antialias) {
wt_max_i = HelperInterpBase::_compute_indices_min_size_weights_aa(
i,
input_size,
scale,
support,
wt_ptr + i * max_interp_size,
max_interp_size,
aa_filter_fn,
xmin,
xsize);
} else {
wt_max_i = HelperInterpBase::_compute_indices_min_size_weights(
i,
input_size,
scale,
wt_ptr + i * max_interp_size,
max_interp_size,
aa_filter_fn,
align_corners,
xmin,
xsize);
}
wt_max = std::max(wt_max, wt_max_i);
idx_ptr_xmin[i] = xmin * stride;
idx_ptr_size[i] = xsize;
idx_ptr_stride[i] = stride;
wt_idx_ptr[i] = i * max_interp_size * weight_index_stride;
}
return {output, max_interp_size, wt_max};
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free