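_tensor_split_indices Function Template: PyTorch Architecture
Architecture documentation for the _tensor_split_indices function template in TensorShape.cpp from the PyTorch codebase. It splits self along dimension dim at the positions listed in indices and returns indices.size() + 1 slices. The first slice starts at 0, each slice ends at the next index, and the last slice runs to the end of that dimension. A tensor with zero dimensions is rejected, and a negative dim is wrapped with maybe_wrap_dim.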
Source Code
aten/src/ATen/native/TensorShape.cpp lines 1129–1151
template <typename T>
static std::vector<Tensor> _tensor_split_indices(
    const Tensor& self,
    ArrayRef<T> indices,
    int64_t dim) {
  TORCH_CHECK(
      self.dim() > 0,
      "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ",
      self.dim(),
      " dims");
  int64_t dim_ = maybe_wrap_dim(dim, self.dim());
  int64_t num_indices = indices.size();
  std::vector<Tensor> splits(num_indices + 1);
  T start_idx(0);
  for (const auto split_idx : c10::irange(num_indices)) {
    auto end_idx = indices[split_idx];
    splits[split_idx] = at::symint::slice<T>(self, dim_, start_idx, end_idx);
    start_idx = end_idx;
  }
  splits[num_indices] = at::symint::slice<T>(
      self, dim_, start_idx, at::symint::size<T>(self, dim_));
  return splits;
}