Home / Class / Shape_ Class — PyTorch Architecture

Shape_ Class — PyTorch Architecture

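Architecture documentation for the `PredicatedTileAccessIteratorResidualLast` specialization for `layout::AffineRankN<2>` (first template parameter `Shape_`) in `predicated_tile_access_iterator_residual_last.h` from the PyTorch codebase.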

Entity Profile

Source Code

aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h lines 899–1068 (the excerpt below ends after the main constructor)

class PredicatedTileAccessIteratorResidualLast<
    Shape_,
    Element_,
    layout::AffineRankN<2>,
    AdvanceRank,
    ThreadMap_,
    AccessType_,
    false> {
 public:
  // Only rank 0 (contiguous) and rank 1 (strided) advancement make sense for a
  // rank-2 affine layout.
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for AffineRankN<2> iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");

  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::AffineRankN<2>;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;
  using AccessType = AccessType_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  /// Element pointer with any const qualifier removed; used to turn the
  /// user-supplied pointer into the internal byte pointer.
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  /// Predicate helper. Note that predicates are tracked in pitch-linear
  /// coordinates (layout::PitchLinear) even though this specialization's
  /// Layout is AffineRankN<2>.
  using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates<
      Shape,
      Element,
      layout::PitchLinear,
      AdvanceRank,
      ThreadMap,
      AccessType>;

  /// Number of AccessType-sized accesses needed to cover one thread-map
  /// vector of ThreadMap::kElementsPerAccess elements.
  static int const kAccessesPerVector =
      ThreadMap::kElementsPerAccess / AccessType::kElements;

  static_assert(
      !(ThreadMap::kElementsPerAccess % AccessType::kElements),
      "Vectors implied by the thread map must be divisible by the access type.");

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingPredicates::Mask;

  /// Parameters object is precomputed state and is host-constructible.
  ///
  /// Holds the layout stride together with the byte increments the iterator
  /// applies to its internal byte pointer while walking a tile. Every
  /// increment is computed as `elements * sizeof_bits<Element>::value / 8`,
  /// i.e. it is expressed in bytes.
  class Params {
   public:
    friend PredicatedTileAccessIteratorResidualLast;

   private:
    /// stride of the affine rank-2 layout (units of Element)
    Coord<Layout::kStrideRank, Layout::LongIndex> stride_;
    /// amount (in bytes) to increment pointer to move to next access along
    /// contiguous dimension
    LongIndex inc_contiguous_;
    /// amount (in bytes) to increment pointer from first access of current
    /// contiguous dimension to first access of next one.
    LongIndex inc_strided_;
    /// amount (in bytes) to increment pointer from last access of current
    /// contiguous dimension to first access of next one.
    LongIndex inc_next_strided_;
    /// amount (in bytes) to increment pointer from last access to first access
    /// of next tile
    LongIndex inc_next_;
    /// amount (in bytes) to increment pointer from first access of current tile
    /// to first access of next tile
    LongIndex inc_advance_;

   public:
    /// Default ctor: zeroes the stride and every increment (including
    /// inc_next_strided_) so a default-constructed Params never holds
    /// indeterminate values. Initializers follow declaration order.
    CUTLASS_HOST_DEVICE
    Params()
        : stride_(0),
          inc_contiguous_(0),
          inc_strided_(0),
          inc_next_strided_(0),
          inc_next_(0),
          inc_advance_(0) {}

    /// Construct the Params object given an affine rank-2 tensor layout.
    ///
    /// @param layout  layout whose stride(0) / stride(1) (in elements) are
    ///                copied into stride_ and used to derive all byte
    ///                increments below.
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : stride_({layout.stride(0), layout.stride(1)}) {
      // Each product is formed before dividing by 8, which keeps the byte
      // count exact for sub-byte element types whenever the total number of
      // bits is a multiple of 8.
      inc_contiguous_ =
          (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) *
          sizeof_bits<Element>::value / 8;

      inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) *
          sizeof_bits<Element>::value / 8;

      // Rewind the (Iterations::kContiguous - 1) contiguous steps taken along
      // one row of accesses, then take one strided step.
      inc_next_strided_ = inc_strided_ -
          LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_;

      if (kAdvanceRank) {
        // advance along strided dimension
        inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) *
            sizeof_bits<Element>::value / 8;
      } else {
        // advance along contiguous dimension
        inc_advance_ = Shape::kContiguous * LongIndex(stride_[0]) *
            sizeof_bits<Element>::value / 8;
      }

      // From the last access of the current tile, rewind all contiguous and
      // strided steps taken inside the tile, then advance by one whole tile.
      inc_next_ = inc_advance_ -
          LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ -
          LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_;
    }
  };

 private:
  /// Internal pointer type permits fast address arithmetic.
  /// Byte-granular on purpose: the increments stored in Params are in bytes.
  using BytePointer = char*;

  //
  // Data members
  //

  /// Parameters object with precomputed internal state
  /// (stride and byte increments, copied in by the constructor)
  Params params_;

  /// Internal pointer to first access of tile
  BytePointer pointer_;

  /// Per-thread predicate state. Constructed from the tensor extent and then
  /// set from the thread id / threadblock offset in the constructor.
  UnderlyingPredicates the_predicates;
  /// Saved predicate mask, presumably applied when processing the residual
  /// tile. NOTE(review): it is not assigned by the constructor in this
  /// excerpt; confirm where it is initialized before it is read.
  Mask residual_tile_mask;

 private:
  /// Computes predicates based on internally tracked per-thread offset.
  ///
  /// Thin forwarder: delegates directly to
  /// UnderlyingPredicates::compute_predicates_ with the same arguments.
  CUTLASS_DEVICE
  void compute_predicates_(
      /// Extent of the matrix window
      TensorCoord extent,
      /// optionally, simplify predicate calculation during 'steady state' phase
      bool is_steady_state = false) {
    the_predicates.compute_predicates_(extent, is_steady_state);
  }

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID.
  ///
  /// Copies `params`, keeps `pointer` as an internal (non-const) byte pointer,
  /// builds the predicates for `extent` and sets them for this thread. It
  /// then moves the internal pointer to the thread's starting coordinate as
  /// mapped through the affine layout.
  ///
  /// @param params              precomputed parameters object
  /// @param pointer             pointer to start of tensor
  /// @param extent              extent of tensor
  /// @param thread_id           ID of each participating thread
  /// @param threadblock_offset  initial offset of threadblock
  /// @param indices             gather/scatter indices; unused, since this
  ///                            specialization does not support gather/scatter
  CUTLASS_HOST_DEVICE
  PredicatedTileAccessIteratorResidualLast(
      Params const& params,
      Pointer pointer,
      TensorCoord extent,
      int thread_id,
      TensorCoord const& threadblock_offset,
      int const* indices = nullptr)
      : params_(params),
        pointer_(reinterpret_cast<BytePointer>(
            const_cast<NonConstPointer>(pointer))),
        the_predicates(extent) {
    the_predicates.set_predicates(thread_id, threadblock_offset);

    // Rebuild the affine layout from the stored stride and use it to turn the
    // per-thread coordinate chosen by the predicates into a pointer offset.
    Layout affine_layout(params_.stride_);
    add_pointer_offset(affine_layout(the_predicates.thread_offset_));

    // NOTE(review): residual_tile_mask is not assigned here; confirm the rest
    // of the class sets it before reading it.
  }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free