BroadcastLinearIndices Class — pytorch Architecture
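Architecture documentation for the BroadcastLinearIndices class in LinearAlgebraUtils.h from the PyTorch codebase. The class maps each linear index of a broadcast tensor to the linear index of the corresponding element in the original, un-broadcast tensor.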
Source Code
aten/src/ATen/native/LinearAlgebraUtils.h lines 556–580
class BroadcastLinearIndices {
 private:
  Tensor linear_indices_;
  bool is_broadcasting_;

 public:
  BroadcastLinearIndices(
      int64_t numel,
      IntArrayRef original_shape,
      IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
    // The assumption is that the broadcast_shape is a materialized broadcast
    // shape of the original_shape. We need to compute the linear indices
    // compatible with the original_shape to access the elements in the original
    // tensor corresponding to the broadcast tensor.
    if (is_broadcasting_) {
      linear_indices_ =
          get_linear_indices(numel, original_shape, broadcast_shape);
    }
  }
  int64_t operator()(int64_t broadcast_linear_index) {
    return is_broadcasting_
        ? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
        : broadcast_linear_index;
  }
};
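To make the index mapping concrete, here is a minimal, dependency-free C++ sketch. It assumes that `get_linear_indices` (not shown in this excerpt) behaves like `at::arange(numel).view(original_shape).broadcast_to(broadcast_shape).contiguous()`, i.e. it produces, for every element of the broadcast tensor in row-major order, the row-major linear index of the source element in the original tensor. The names `broadcast_linear_indices` and `BroadcastIndexMap` are hypothetical, use `std::vector` in place of ATen tensors, and are not part of PyTorch.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for ATen's get_linear_indices (assumption): for each
// position of a row-major tensor with broadcast_shape, return the row-major
// linear index of the original_shape element it was broadcast from.
std::vector<int64_t> broadcast_linear_indices(const std::vector<int64_t>& original_shape,
                                              const std::vector<int64_t>& broadcast_shape) {
  if (original_shape.size() > broadcast_shape.size()) {
    throw std::invalid_argument("original shape has more dims than broadcast shape");
  }
  // Broadcasting aligns trailing dimensions; missing leading dims act like size 1.
  const size_t offset = broadcast_shape.size() - original_shape.size();
  for (size_t od = 0; od < original_shape.size(); ++od) {
    if (original_shape[od] != 1 && original_shape[od] != broadcast_shape[od + offset]) {
      throw std::invalid_argument("shapes are not broadcast-compatible");
    }
  }

  // Row-major strides of the original shape.
  std::vector<int64_t> strides(original_shape.size(), 1);
  for (size_t d = original_shape.size(); d-- > 1;) {
    strides[d - 1] = strides[d] * original_shape[d];
  }

  int64_t total = 1;
  for (int64_t s : broadcast_shape) total *= s;

  std::vector<int64_t> result(static_cast<size_t>(total));
  for (int64_t i = 0; i < total; ++i) {
    int64_t rem = i;
    int64_t linear = 0;
    // Peel off broadcast coordinates from the last dimension to the first.
    for (size_t bd = broadcast_shape.size(); bd-- > 0;) {
      const int64_t coord = rem % broadcast_shape[bd];
      rem /= broadcast_shape[bd];
      if (bd < offset) continue;             // dim absent from the original shape
      const size_t od = bd - offset;
      if (original_shape[od] == 1) continue;  // broadcast dim: source coordinate is 0
      linear += coord * strides[od];
    }
    result[static_cast<size_t>(i)] = linear;
  }
  return result;
}

// Hypothetical analogue of BroadcastLinearIndices: builds the lookup table only
// when the shapes differ, otherwise returns the index unchanged.
class BroadcastIndexMap {
 public:
  BroadcastIndexMap(const std::vector<int64_t>& original_shape,
                    const std::vector<int64_t>& broadcast_shape)
      : is_broadcasting_(original_shape != broadcast_shape) {
    if (is_broadcasting_) {
      linear_indices_ = broadcast_linear_indices(original_shape, broadcast_shape);
    }
  }

  int64_t operator()(int64_t broadcast_linear_index) const {
    return is_broadcasting_
        ? linear_indices_[static_cast<size_t>(broadcast_linear_index)]
        : broadcast_linear_index;
  }

 private:
  std::vector<int64_t> linear_indices_;
  bool is_broadcasting_;
};
```

For example, with an original shape (3, 1) broadcast to (2, 3, 4), `broadcast_linear_indices({3, 1}, {2, 3, 4})` yields `0 0 0 0 1 1 1 1 2 2 2 2` repeated twice: all 24 broadcast positions read from only three source elements. When the two shapes are equal, both the class above and this sketch skip building the table and return the index as-is, so the common non-broadcasting case costs no extra allocation.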