_compute_index_ranges_int16_weights Function — pytorch Architecture
Architecture documentation for the _compute_index_ranges_int16_weights function template in UpSampleKernel.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/UpSampleKernel.cpp lines 985–1032
// Computes interpolation index ranges and (antialiased) weights for one
// dimension via HelperInterpBase::_compute_index_ranges_weights, then converts
// the double weights stored in indices_weights[3] *in place* into fixed-point
// int16 values.
//
// Returns a tuple of:
//   - indices_weights: the tensors returned by _compute_index_ranges_weights.
//     The storage of element [3] is reused: its first
//     output_size * aligned_interp_size int16 entries hold the fixed-point
//     weights, row j starting at j * aligned_interp_size.
//   - aligned_interp_size: the int16 row stride (weights per output position,
//     including any padding, which is zero).
//   - weights_precision: number of fractional bits; an int16 weight w stands
//     for w / 2^weights_precision.
//
// align_i32: when true, the int16 row stride is rounded up so that rows can be
// read as int32 values (see ImagingResampleHorizontalConvolution8u4x).
template <typename aa_filter_fn_t>
static inline std::tuple<std::vector<Tensor>, int, unsigned int> _compute_index_ranges_int16_weights(
    int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
    int64_t reshape_dim, bool align_corners, const std::optional<double>& opt_scale,
    int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, bool align_i32=false
) {
  double scale = area_pixel_compute_scale<double>(
      input_size, output_size, align_corners, opt_scale);

  auto [indices_weights, aligned_interp_size, wt_max] = HelperInterpBase::_compute_index_ranges_weights<double, aa_filter_fn_t, sizeof(int16_t)>(
      input_size, output_size, stride, ndims, reshape_dim, scale, interp_size, aa_filter_fn, antialias, align_corners);
  // From here on, interp_size is the row stride of the double weights buffer.
  interp_size = aligned_interp_size;

  // Bind by reference: only the raw data pointer is needed, so there is no
  // reason to copy the Tensor handle (and bump its refcount).
  auto& weights_f64 = indices_weights[3];
  double* data_f64 = weights_f64.template data_ptr<double>();

  // Choose the number of fractional bits: stop at the first precision p for
  // which one more bit would push the largest weight past the int16 range
  // (round(wt_max * 2^(p + 1)) >= 2^15). If that never happens, p ends at 22.
  unsigned int weights_precision = 0;
  for (weights_precision = 0; weights_precision < 22; ++weights_precision) {
    const int next_value = static_cast<int>(0.5 + wt_max * (1 << (weights_precision + 1)));
    if (next_value >= (1 << 15)) {
      break;
    }
  }

  // The int16 weights are written into the same storage as the doubles.
  int16_t* data_i16 = reinterpret_cast<int16_t*>(data_f64);
  if (align_i32) {
    // The int16 data will be loaded as int32 (see
    // ImagingResampleHorizontalConvolution8u4x,
    // mmk0 = _mm256_set1_epi32(*(int32_t*)&k[x])), so every int16 row must
    // respect int32 alignment. The stride is rounded up to a multiple of
    // sizeof(int32_t) == 4 int16 elements. NOTE(review): a multiple of 2
    // elements would already be int32-aligned; confirm whether the stricter
    // multiple of 4 is relied upon before relaxing it.
    while (aligned_interp_size % sizeof(int32_t) != 0) {
      aligned_interp_size += 1;
    }
    // The padded int16 row must stay strictly smaller than the double row.
    // That way the in-place conversion below never overwrites a double
    // before it has been read, and the int16 layout fits in the buffer.
    TORCH_INTERNAL_ASSERT(aligned_interp_size * sizeof(int16_t) < interp_size * sizeof(double));
  }

  // Convert row by row: doubles (stride interp_size) become int16 values
  // (stride aligned_interp_size), scaled by 2^weights_precision and rounded
  // half away from zero.
  for (const auto j : c10::irange(output_size)) {
    for (const auto k : c10::irange(interp_size)) {
      const double v = data_f64[j * interp_size + k] * (1 << weights_precision);
      data_i16[j * aligned_interp_size + k] =
          static_cast<int16_t>((v < 0) ? static_cast<int>(-0.5 + v) : static_cast<int>(0.5 + v));
    }
    // Zero the int32-alignment padding. Otherwise it would keep leftover
    // bytes of the original double weights. Every double of row j has already
    // been read, and the padding ends before the first double of row j + 1,
    // so this does not clobber unread data. The range is empty when
    // align_i32 is false.
    for (const auto k : c10::irange(interp_size, aligned_interp_size)) {
      data_i16[j * aligned_interp_size + k] = 0;
    }
  }
  return {indices_weights, aligned_interp_size, weights_precision};
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free