
_tensor_split_indices Class — pytorch Architecture

Architecture documentation for _tensor_split_indices, a static function template in TensorShape.cpp from the PyTorch codebase. It splits a tensor along one dimension at a list of indices.


Source Code

aten/src/ATen/native/TensorShape.cpp lines 1129–1151

```cpp
template <typename T>
static std::vector<Tensor> _tensor_split_indices(
    const Tensor& self,
    ArrayRef<T> indices,
    int64_t dim) {
  TORCH_CHECK(
      self.dim() > 0,
      "tensor_split expected at least a 1-dimensional tensor, but got a tensor with ",
      self.dim(),
      " dims");
  int64_t dim_ = maybe_wrap_dim(dim, self.dim());
  int64_t num_indices = indices.size();
  std::vector<Tensor> splits(num_indices + 1);
  T start_idx(0);
  for (const auto split_idx : c10::irange(num_indices)) {
    auto end_idx = indices[split_idx];
    splits[split_idx] = at::symint::slice<T>(self, dim_, start_idx, end_idx);
    start_idx = end_idx;
  }
  splits[num_indices] = at::symint::slice<T>(
      self, dim_, start_idx, at::symint::size<T>(self, dim_));
  return splits;
}
```
