BroadcastLinearIndices Class — pytorch Architecture

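Architecture documentation for the BroadcastLinearIndices class in LinearAlgebraUtils.h from the PyTorch codebase. The class translates a linear index in a broadcast tensor into the linear index of the corresponding element in the original (pre-broadcast) tensor. When the two shapes are equal, it returns the index unchanged.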

Source Code

aten/src/ATen/native/LinearAlgebraUtils.h lines 556–580

class BroadcastLinearIndices {
 private:
  Tensor linear_indices_;
  bool is_broadcasting_;

 public:
  BroadcastLinearIndices(
      int64_t numel,
      IntArrayRef original_shape,
      IntArrayRef broadcast_shape) : is_broadcasting_(!original_shape.equals(broadcast_shape)) {
    // The assumption is that the broadcast_shape is a materialized broadcast
    // shape of the original_shape. We need to compute the linear indices
    // compatible with the original_shape to access the elements in the original
    // tensor corresponding to the broadcast tensor.
    if (is_broadcasting_) {
      linear_indices_ =
          get_linear_indices(numel, original_shape, broadcast_shape);
    }
  }
  int64_t operator()(int64_t broadcast_linear_index) {
    return is_broadcasting_
        ? linear_indices_.data_ptr<int64_t>()[broadcast_linear_index]
        : broadcast_linear_index;
  }
};
