CacheOpA Template Parameter — pytorch Architecture
Architecture documentation for the CacheOpA template parameter, as consumed by the DefaultMmaFromSharedMemory partial specialization in mma_from_smem.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h lines 1358–1458
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Warp-level iterator over tiles of A operand in shared memory
    typename WarpIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages
    int Stages,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear,
    int kMaxK,
    /// Whether or not to apply elementwise multiplication of operand A by
    /// another matrix in shared memory before usage in A @ B
    bool kScaleOperandA,
    bool kTransposeA>
struct DefaultMmaFromSharedMemory<
    MmaMultistage<
        Shape_,
        IteratorA_,
        SmemIteratorA_,
        CacheOpA,
        IteratorB_,
        SmemIteratorB_,
        CacheOpB,
        ElementC_,
        LayoutC_,
        Policy_,
        Stages,
        SharedMemoryClear>,
    kMaxK,
    WarpIteratorA_,
    kScaleOperandA,
    kTransposeA> {
  using RegularMma = MmaMultistage<
      Shape_,
      IteratorA_,
      SmemIteratorA_,
      CacheOpA,
      IteratorB_,
      SmemIteratorB_,
      CacheOpB,
      ElementC_,
      LayoutC_,
      Policy_,
      Stages,
      SharedMemoryClear>;

  using WarpShape = typename Policy_::Operator::Shape;
  using InstructionShape = typename Policy_::Operator::InstructionShape;
  using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
  static constexpr bool kIsTransposedA =
      WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
  using WarpIteratorA = typename platform::conditional<
      kIsTransposedA,
      typename WarpIteratorTranspose::Iterator,
      WarpIteratorA_>::type;

  // Reduce the number of stages if we don't need that many
  static int constexpr kStagesMax =
      (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
  static int constexpr kStages = cutlass::const_min(Stages, kStagesMax);

  using IteratorB =
      typename cutlass::transform::threadblock::MakeIteratorResidualLast<
          IteratorB_>::Iterator;
  using Mma =
      typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory<
          Shape_,
          WarpIteratorA,
          kScaleOperandA,
          IteratorB,
          SmemIteratorB_,
          RegularMma::kCacheOpB,
          ElementC_,
          LayoutC_,
          Policy_,
          kStages,
          kMaxK>;
};
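DefaultMmaFromSharedMemory dispatches by partial specialization on the wrapped threadblock MMA: this specialization matches whenever its first template argument is an MmaMultistage instantiation, and template argument deduction recovers the multistage MMA's parameters for reuse. A minimal sketch of that dispatch idiom, with heavily reduced stand-in types (none of these simplified signatures are the real ones):

// Primary template: left undefined, so unsupported MMA types fail to compile.
template <typename Mma, int kMaxK>
struct DefaultMmaFromSharedMemory;

// Simplified stand-in for the real CUTLASS multistage threadblock MMA.
template <typename Shape, int Stages>
struct MmaMultistage {};

// Partial specialization: fires whenever the first argument is an
// MmaMultistage instantiation, recovering Shape and Stages by deduction.
template <typename Shape, int Stages, int kMaxK>
struct DefaultMmaFromSharedMemory<MmaMultistage<Shape, Stages>, kMaxK> {
  static constexpr int kStages = Stages; // the real code clamps this
};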
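The warp iterator for operand A is swapped for its transposed counterpart only when both conditions hold: the iterator supports transposition and the caller requested it via kTransposeA. This is the standard conditional-type-selection idiom (platform::conditional is CUTLASS's equivalent of std::conditional). A self-contained sketch with assumed stand-in iterator types, not the source's:

#include <type_traits>

struct RowMajorWarpIterator {};
struct TransposedWarpIterator {};

// Mirrors TransposeWarpIterator<WarpIteratorA_>: advertises whether a
// transposed variant exists and, if so, names it.
template <typename It>
struct TransposeWarpIterator {
  static constexpr bool kSupportsTranspose = false;
  using Iterator = It; // fallback: unchanged
};

template <>
struct TransposeWarpIterator<RowMajorWarpIterator> {
  static constexpr bool kSupportsTranspose = true;
  using Iterator = TransposedWarpIterator;
};

template <typename WarpIteratorA_, bool kTransposeA>
struct PickWarpIteratorA {
  using Transpose = TransposeWarpIterator<WarpIteratorA_>;
  static constexpr bool kIsTransposedA =
      Transpose::kSupportsTranspose && kTransposeA;
  // Same conditional pattern as the specialization above.
  using type = typename std::conditional<
      kIsTransposedA,
      typename Transpose::Iterator,
      WarpIteratorA_>::type;
};

static_assert(
    std::is_same<PickWarpIteratorA<RowMajorWarpIterator, true>::type,
                 TransposedWarpIterator>::value,
    "transposed iterator chosen when supported and requested");
static_assert(
    std::is_same<PickWarpIteratorA<RowMajorWarpIterator, false>::type,
                 RowMajorWarpIterator>::value,
    "original iterator kept when transpose not requested");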
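The kStagesMax / kStages pair is a ceiling division followed by a clamp: kStagesMax is the number of Shape_::kK-wide tiles needed to cover kMaxK, and the software pipeline never allocates more stages than that. A standalone sketch of the same arithmetic, using std::min in place of cutlass::const_min (function and parameter names here are illustrative, not from the source):

#include <algorithm>

// Ceil-divide max_k by the threadblock K extent to bound how many K tiles
// the pipeline can ever consume, then clamp the requested stage count.
constexpr int clamp_stages(int requested_stages, int max_k, int tile_k) {
  int stages_max = (max_k + tile_k - 1) / tile_k; // ceil(max_k / tile_k)
  return std::min(requested_stages, stages_max);
}

// With kMaxK = 96 and Shape_::kK = 32, only 3 K tiles exist, so a
// 4-stage request drops to 3; with kMaxK = 256, all 4 stages are kept.
static_assert(clamp_stages(4, 96, 32) == 3, "4 stages clamp to 3 tiles");
static_assert(clamp_stages(4, 256, 32) == 4, "8 tiles available, 4 kept");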