Home / Class / CacheOpA Class — pytorch Architecture

CacheOpA Class — pytorch Architecture

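Architecture documentation for CacheOpA in mma_from_smem.h from the pytorch codebase. In the excerpt below, CacheOpA is a non-type template parameter (of type cutlass::arch::CacheOperation::Kind) of the DefaultMmaFromSharedMemory partial specialization for MmaMultistage, rather than a class of its own.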

Entity Profile

Source Code

aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h lines 1358–1458

/// Partial specialization of DefaultMmaFromSharedMemory that matches a
/// threadblock-level MmaMultistage and exposes, as `Mma`, the corresponding
/// MmaMultistageFromSharedMemory type. The matched multistage type is
/// re-exported unchanged as `RegularMma`.
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    /// (concept: ReadableTileIterator | ForwardTileIterator |
    /// MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Warp-level iterator for operand A; may be swapped for its transposed
    /// counterpart (see WarpIteratorA below)
    typename WarpIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    /// (concept: ReadableTileIterator | ForwardTileIterator |
    /// MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    /// Value used to cap the stage count (see kStagesMax) and forwarded to Mma
    int kMaxK,
    /// whether or not to apply elementwise multiplication of operand A by
    /// another matrix in shared memory before usage in A @ B
    bool kScaleOperandA,
    /// Request the transposed warp iterator for operand A; only honored when
    /// TransposeWarpIterator reports kSupportsTranspose
    bool kTransposeA>
struct DefaultMmaFromSharedMemory<
    MmaMultistage<
        Shape_,
        IteratorA_,
        SmemIteratorA_,
        CacheOpA,
        IteratorB_,
        SmemIteratorB_,
        CacheOpB,
        ElementC_,
        LayoutC_,
        Policy_,
        Stages,
        SharedMemoryClear>,
    kMaxK,
    WarpIteratorA_,
    kScaleOperandA,
    kTransposeA> {
  // ---- The multistage Mma this specialization was matched against ----
  using RegularMma = MmaMultistage<
      Shape_,
      IteratorA_,
      SmemIteratorA_,
      CacheOpA,
      IteratorB_,
      SmemIteratorB_,
      CacheOpB,
      ElementC_,
      LayoutC_,
      Policy_,
      Stages,
      SharedMemoryClear>;

  // ---- Pipeline depth ----
  // No point in keeping more stages than there are K-tiles to load:
  // kStagesMax is ceil(kMaxK / Shape_::kK), and kStages is clamped to it.
  static constexpr int kStagesMax =
      (kMaxK + static_cast<int>(Shape_::kK) - 1) /
      static_cast<int>(Shape_::kK);
  static constexpr int kStages = cutlass::const_min(Stages, kStagesMax);

  // ---- Shapes taken from the warp-level operator in the policy ----
  using WarpShape = typename Policy_::Operator::Shape;
  using InstructionShape = typename Policy_::Operator::InstructionShape;

  // ---- Operand A warp iterator, optionally transposed ----
  using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
  // Transposition happens only if it was requested AND the iterator type
  // supports it; otherwise the original warp iterator is used as-is.
  static constexpr bool kIsTransposedA =
      WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
  using WarpIteratorA = typename platform::conditional<
      kIsTransposedA,
      typename WarpIteratorTranspose::Iterator,
      WarpIteratorA_>::type;

  // ---- Operand B global-memory iterator ----
  // Obtained by passing IteratorB_ through MakeIteratorResidualLast.
  using IteratorB =
      typename cutlass::transform::threadblock::MakeIteratorResidualLast<
          IteratorB_>::Iterator;

  // ---- Resulting Mma ----
  // Built from the (possibly transposed) A warp iterator, the adapted B
  // iterator, and the clamped stage count.
  using Mma =
      typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory<
          Shape_,
          WarpIteratorA,
          kScaleOperandA,
          IteratorB,
          SmemIteratorB_,
          RegularMma::kCacheOpB,
          ElementC_,
          LayoutC_,
          Policy_,
          kStages,
          kMaxK>;
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free