ScaleOperandA_ Template Parameter — pytorch Architecture
Architecture documentation for the ScaleOperandA_ template parameter — the operand-A scaling switch of the MmaPipelinedFromSharedMemory class — in mma_from_smem.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h lines 340–515
template <
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    // BEGIN smem
    /// Iterates over the intermediate accumulator tile in shared memory
    typename WarpIteratorA_,
    /// whether or not to perform elementwise multiplication of A
    // by another matrix (A_scale) that is also kept in shared memory prior
    // to matmul A @ B
    bool ScaleOperandA_,
    /// Max GEMM problem size in K dimension
    int MaxK,
    /// Iterates over tiles of B operand in global memory
    // (concept: ReadableTileIterator | ForwardTileIterator |
    // MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Transformation applied to B operand when staging from global to
    /// shared memory (defaults to an element-type conversion)
    typename TransformB_ = NumericArrayConverter<
        typename SmemIteratorB_::Element,
        typename IteratorB_::Element,
        IteratorB_::Fragment::kElements>,
    /// Used for partial specialization
    typename Enable = bool>
class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
                                         Shape_,
                                         MaxK,
                                         Policy_,
                                         2, // kStages: double-buffered pipeline
                                         typename WarpIteratorA_::Layout> {
public:
  ///< Base class
  using Base = MmaBaseFromSharedMemory<
      Shape_,
      MaxK,
      Policy_,
      2,
      typename WarpIteratorA_::Layout>;
  using Shape =
      Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
  /// true if fragments of operand A are elementwise-multiplied by fragments
  /// of A_scale after being loaded from shared memory
  static constexpr bool ScaleOperandA = ScaleOperandA_;
  using WarpIteratorA = WarpIteratorA_;
  ///< loads fragments of A_scale from shared memory if operand A scaling is
  ///< enabled. otherwise no-op.
  using WarpIteratorAScale = typename cutlass::platform::conditional<
      ScaleOperandA,
      WarpIteratorA,
      NoOpWarpIteratorScale<typename WarpIteratorA::TensorRef>>::type;
  using IteratorB =
      IteratorB_; ///< Iterates over tiles of B operand in global memory
  using ElementC = ElementC_; ///< Data type of accumulator matrix
  using LayoutC = LayoutC_; ///< Layout of accumulator matrix
  using Policy = Policy_; ///< Policy describing tuning details
  using SmemIteratorB = SmemIteratorB_; ///< Writes B tiles to shared memory
  using TransformB = TransformB_; ///< Converts B fragments on load from global

  //
  // Dependent types
  //

  /// Fragment of operand B loaded from global memory
  using FragmentB = typename IteratorB::Fragment;
  /// Fragment of accumulator tile
  using FragmentC = typename Policy::Operator::FragmentC;
  /// Warp-level Mma
  using Operator = typename Policy::Operator;
  /// Obtain the arch tag from the warp-level operator
  using ArchTag = typename Policy::Operator::ArchTag;
  /// Complex transform on B operand
  static ComplexTransform const kTransformB = Operator::kTransformB;

  // statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
  static_assert(
      (Base::kStages == 2),
      "MmaPipelined requires kStages set to value 2");

private:
  /// Fragment of operand A held in registers for one warp-level MMA
  using WarpFragmentA = typename Operator::FragmentA;
  /// fragment type of OperandA elementwise scaling matrix. (almost) empty
  /// if operand A scaling is disabled.
  using WarpFragmentAScale = typename WarpIteratorAScale::Fragment;
  /// Fragment of operand B held in registers for one warp-level MMA
  using WarpFragmentB = typename Operator::FragmentB;
  /// applies scaling factor to operand A fragment if operand A scaling is
  /// enabled. otherwise no-op.
  using FragmentAScaler = FragmentElementwiseScaler<
      WarpFragmentA,
      WarpFragmentAScale,
      ScaleOperandA>;

protected:
  // /// Iterator to write threadblock-scoped tile of A operand to shared memory
  // SmemIteratorA smem_iterator_A_;
  /// Iterator to write threadblock-scoped tile of B operand to shared memory
  SmemIteratorB smem_iterator_B_;
  /// Iterator to load a warp-scoped tile of A operand from intermediate
  /// accumulator tile
  WarpIteratorA warp_tile_iterator_A_;
  /// Iterator to load a warp-scoped tile of A_scale from intermediate
  /// accumulator tile (only used if ScaleOperandA_ is true)
  WarpIteratorAScale warp_tile_iterator_A_scale_;
public:
/// constructor for MMA with operand A scaling enabled.
CUTLASS_DEVICE
MmaPipelinedFromSharedMemory(
    typename Base::TensorRefA a, // Operand A in shared memory
    typename Base::TensorRefA a_scale, // Operand A_scale in shared memory
    typename Base::TensorRefB
        b_staging, // staging memory for loading tiles of B
    int thread_idx,
    int warp_idx,
    int lane_idx)
    : Base(b_staging, thread_idx, warp_idx, lane_idx),
      warp_tile_iterator_A_(a, lane_idx),
      warp_tile_iterator_A_scale_(a_scale, lane_idx),
      smem_iterator_B_(b_staging, thread_idx) {
  // Decompose the linear warp id into (m, n, k) coordinates within the
  // threadblock tile:
  //   _m: the warp's position within the threadblock along the M dimension
  //   _n: the warp's position within the threadblock along the N dimension
  //   _k: the warp's position within the threadblock along the K dimension
  int const warps_per_mn_slice = Base::WarpCount::kM * Base::WarpCount::kN;
  int const warp_idx_in_slice = warp_idx % warps_per_mn_slice;
  int const warp_idx_k = warp_idx / warps_per_mn_slice;
  int const warp_idx_m = warp_idx_in_slice % Base::WarpCount::kM;
  int const warp_idx_n = warp_idx_in_slice / Base::WarpCount::kM;

  // Seat each warp-level iterator at its warp's tile, in units of warp-level
  // tiles; A and A_scale advance along K by whole warp-gemm iterations.
  int const k_tile_offset = Base::kWarpGemmIterations * warp_idx_k;
  this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, k_tile_offset});
  this->warp_tile_iterator_A_scale_.add_tile_offset(
      {warp_idx_m, k_tile_offset});
  this->warp_tile_iterator_B_.add_tile_offset({k_tile_offset, warp_idx_n});
}
/// Construct from tensor references
CUTLASS_DEVICE
MmaPipelinedFromSharedMemory(
typename Base::TensorRefA a, ///< Operand A in shared memory
typename Base::TensorRefB b_staging, ///< staging memory for loading B
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx) ///< ID of each thread within a warp
: Base(b_staging, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(a, lane_idx),
smem_iterator_B_(b_staging, thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free