index_func_meta_impl Function — pytorch Architecture
Architecture documentation for the index_func_meta_impl template function in TensorAdvancedIndexing.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/TensorAdvancedIndexing.cpp lines 354–436
// Shared meta (shape/dtype validation + output declaration) routine for the
// index_* structured operators. Validates `index`, `dim`, and `source`
// against `self`, then declares output 0 as a tensor with `self`'s sizes and
// options. All error messages are prefixed with `func` + "_()".
//
// Parameters:
//   meta   - structured-kernel meta object; used to query (maybe_get_output)
//            and declare (set_output_raw_strided) output 0.
//   self   - the tensor being indexed into; fixes the output's sizes/options.
//   dim    - dimension along which indexing happens.
//            NOTE(review): assumed already wrapped to a non-negative value by
//            the caller -- the `dim == 0 || dim < source.dim()` check below
//            would accept any negative dim; confirm against call sites.
//   index  - 0-d or 1-d tensor of int32/int64 indices.
//   source - tensor supplying values; must match self's dtype, and self's
//            shape once `dim` is excluded.
//   func   - operator name used only to build error messages (presumably
//            e.g. "index_add"; verify against callers).
template <typename Meta>
static void index_func_meta_impl(
Meta& meta,
const Tensor& self,
int64_t dim,
const Tensor& index,
const Tensor& source,
std::string_view func) {
auto numel = index.numel();
// index must be a scalar (0-d) or vector (1-d) tensor.
TORCH_CHECK_INDEX(
index.dim() <= 1,
func,
"_(): Index is supposed to be a vector, but got dim: ",
index.dim(),
" with type: ",
index.scalar_type(),
" and size: ",
index.sizes());
// Only int32/int64 indices are accepted.
TORCH_CHECK(
index.scalar_type() == ScalarType::Long ||
index.scalar_type() == ScalarType::Int,
func,
"_(): Expected dtype int32/int64 for index but got: ",
index.scalar_type());
// self and source must share a dtype (no implicit promotion here).
TORCH_CHECK(
self.scalar_type() == source.scalar_type(),
func,
"_(): self (",
self.scalar_type(),
") and source (",
source.scalar_type(),
") must have the same scalar type");
// dim must be valid for source; dim == 0 is allowed even for a 0-d
// source (0-d tensors are treated as if they had one element on dim 0).
TORCH_CHECK(
dim == 0 || dim < source.dim(),
func,
"_(): Indexing dim ",
dim,
" is out of bounds of the source tensor with dim ",
source.dim());
// One index per slice of source along dim (a 0-d source counts as 1).
// NOTE(review): on failure with a 0-d source, the lazily-evaluated
// message argument source.size(dim) would itself throw; in practice the
// condition holds for 0-d sources whenever numel == 1.
TORCH_CHECK(
numel == (source.dim() == 0 ? 1 : source.size(dim)),
func,
"_(): Number of indices (",
numel,
") should be equal to source.size(dim): (",
source.size(dim),
"), for dim: ",
dim);
// Shapes must agree everywhere except on `dim` itself, so drop that
// entry from both size vectors before comparing (skip for 0-d tensors,
// which have no entry to erase).
auto self_sizes = self.sizes().vec();
auto source_sizes = source.sizes().vec();
if (source.dim() != 0 && self.dim() != 0) {
self_sizes.erase(self_sizes.begin() + dim);
source_sizes.erase(source_sizes.begin() + dim);
}
TORCH_CHECK(
self_sizes == source_sizes,
"source tensor shape must match self tensor shape, excluding the specified dimension. Got self.shape = ",
self.sizes(),
" source.shape = ",
source.sizes());
// Capture whether the caller supplied an `out=` tensor BEFORE declaring
// the output, then declare output 0 with self's sizes/options.
auto& result = meta.maybe_get_output(0);
bool is_defined = result.defined();
meta.set_output_raw_strided(0, self.sizes(), {}, self.options());
// Overlap checks only make sense for a pre-existing out tensor.
if (is_defined) {
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, index);
at::assert_no_overlap(result, source);
}
// A hack to run TensorIterator checks in the meta function.
// See comment:
// https://github.com/pytorch/pytorch/pull/65993#discussion_r760307417
// TODO: (@krshrimali) Try inheriting from TensorIteratorBase instead.
// Building the (otherwise unused) iterator triggers TensorIterator's own
// shape/dtype/device validation for a representative slice pair.
if (result.device() == kMeta && result.dim() > 0) {
auto selfSlice = result.select(dim, 0);
auto sourceSlice = source.select(dim, 0);
auto iter =
TensorIterator::borrowing_binary_op(selfSlice, selfSlice, sourceSlice);
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free