tensor_dim_apply3 Function Template — pytorch Architecture
Architecture documentation for the tensor_dim_apply3 function template in TensorDimApply.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/TensorDimApply.h lines 7–66
/// Invokes `func` once for every 1-D slice of `self` taken along `dim`,
/// handing it the matching slices of `values` and `indices`.
///
/// `func` is called as
///   func(self_data, values_data, indices_data,
///        self_dim_size, self_stride, values_stride, indices_stride)
/// where the three pointers address the first element of the current slice,
/// `self_dim_size` is `self.size(dim)` and the strides are each tensor's
/// stride along `dim`.
///
/// All three tensors are walked in lock step using the sizes of `self` and
/// each tensor's own strides.
/// NOTE(review): this presumes `values` and `indices` have the same shape as
/// `self` — nothing here checks it; confirm against callers.
///
/// @tparam T1       element type used to access `self` and `values`
/// @tparam T2       element type used to access `indices`
/// @tparam Function callable accepting the seven arguments listed above
/// @param self      tensor that is only read
/// @param values    tensor accessed through a mutable `T1*`
/// @param indices   tensor accessed through a mutable `T2*`
/// @param dim       dimension along which slices are taken
/// @param func      callable invoked once per slice
template <typename T1, typename T2, typename Function>
void tensor_dim_apply3(
    const Tensor& self,
    Tensor& values,
    Tensor& indices,
    int64_t dim,
    Function func) {
  int ndims = self.dim();
  const T1* self_data = self.const_data_ptr<T1>();
  T1* values_data = values.data_ptr<T1>();
  T2* indices_data = indices.data_ptr<T2>();
  int64_t self_stride = self.stride(dim);
  int64_t values_stride = values.stride(dim);
  int64_t indices_stride = indices.stride(dim);
  // Kept as `int` so the argument types seen by `func` stay unchanged.
  int self_dim_size = self.size(dim);

  // The loop indices below are never negative, so a negative `dim` could never
  // match one of them and the sliced dimension would be walked like an outer
  // dimension. Compare against a normalized copy instead; `dim` itself was
  // passed untouched to stride()/size() above.
  const int64_t slice_dim = dim < 0 ? dim + ndims : dim;

  // If any outer dimension is empty there is no slice to visit. Without this
  // guard `func` would run on an empty tensor's data, and the counter for that
  // dimension (incremented to 1, compared against 0) would never match its
  // size, so the traversal would never terminate.
  for (const auto dim_i : c10::irange(ndims)) {
    if (dim_i != slice_dim && self.size(dim_i) == 0) {
      return;
    }
  }

  // Multi-index over the outer dimensions; the entry for `slice_dim` stays 0.
  std::vector<int64_t> counter(ndims, 0);
  bool finished = false;
  while (!finished) {
    func(
        self_data,
        values_data,
        indices_data,
        self_dim_size,
        self_stride,
        values_stride,
        indices_stride);

    // A 1-D tensor consists of exactly one slice.
    if (ndims == 1) {
      break;
    }

    // Advance the multi-index by one, lowest dimension first, carrying into
    // the next dimension whenever one wraps around.
    for (const auto dim_i : c10::irange(ndims)) {
      if (dim_i == slice_dim) {
        // The sliced dimension is not iterated. Reaching it as the last
        // dimension means every lower dimension has just wrapped: done.
        if (dim_i == ndims - 1) {
          finished = true;
          break;
        }
        continue;
      }

      counter[dim_i]++;
      self_data += self.stride(dim_i);
      values_data += values.stride(dim_i);
      indices_data += indices.stride(dim_i);

      if (counter[dim_i] != self.size(dim_i)) {
        // Still inside this dimension: the next slice is ready.
        break;
      }
      if (dim_i == ndims - 1) {
        // The last dimension overflowed, so all slices have been visited.
        finished = true;
        break;
      }

      // Rewind this dimension to its start and carry into the next one.
      self_data -= counter[dim_i] * self.stride(dim_i);
      values_data -= counter[dim_i] * values.stride(dim_i);
      indices_data -= counter[dim_i] * indices.stride(dim_i);
      counter[dim_i] = 0;
    }
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free