Home / Class / MmaTensorOpComputeBWithF16 Class — pytorch Architecture

MmaTensorOpComputeBWithF16 Class — pytorch Architecture

Architecture documentation for the MmaTensorOpComputeBWithF16 class in mma_tensorop_compute_B_with_f16.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h lines 92–305

class MmaTensorOpComputeBWithF16 {
public:
    /// Shape of warp-level matrix operation (concept: GemmShape)
    using Shape = Shape_;

    /// Data type of multiplicand A as read by IteratorA
    using ElementA = ElementA_;

    /// Layout of multiplicand A
    using LayoutA = LayoutA_;

    /// Data type of multiplicand B as read by IteratorB. The MMA instruction itself consumes
    /// ArchMmaOperator::ElementB (see TransformedFragmentB), which may differ from this type.
    using ElementB = ElementB_;

    /// Layout of multiplicand B
    using LayoutB = LayoutB_;

    /// Data type of accumulator matrix C
    using ElementC = ElementC_;

    /// Layout of accumulator matrix C
    using LayoutC = LayoutC_;

    /// Warp-level tensor-op policy. It must provide `Operator` (the arch-level MMA instruction
    /// wrapper) and `OpDelta` (used by the operand and accumulator tile iterators below).
    using Policy = Policy_;

    /// Underlying matrix multiply operator (concept: arch::Mma)
    using ArchMmaOperator = typename Policy::Operator;

    /// Indicates math operator
    using MathOperator = typename ArchMmaOperator::Operator;

    /// Architecture tag from underlying instruction
    using ArchTag = typename ArchMmaOperator::ArchTag;

    // The instruction must be an FP16 x FP16 HMMA, or a BF16 x BF16 HMMA on SM80 or newer.
    static_assert((platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value
                   && platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value)
                      || (platform::is_same<typename ArchMmaOperator::ElementA, bfloat16_t>::value
                          && platform::is_same<typename ArchMmaOperator::ElementB, bfloat16_t>::value
                          && ArchTag::kMinComputeCapability >= 80),
                  "MmaTensorOpComputeBWithF16 only supports underlying HMMA");

    // The A operand as stored must be FP16, or BF16 on SM80 or newer.
    static_assert(platform::is_same<ElementA, half_t>::value
                      || (platform::is_same<ElementA, bfloat16_t>::value && ArchTag::kMinComputeCapability >= 80),
                  "MmaTensorOpComputeBWithF16 only supports Fp16 A or Bf16 A on Ampere+");

    /// Indicates class of matrix operator
    using OperatorClass = arch::OpClassTensorOp;

    /// Shape of underlying instruction
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Instruction shape to override shared memory iterators with
    using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;

    static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM,
                  "M dimension of compute instruction must match load");
    static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN,
                  "N dimension of compute instruction must match load");

    // kExpansionFactor below uses truncating integer division. Reject shapes for which the
    // result would be wrong (non-multiple) or zero, instead of silently mis-indexing B.
    static_assert(SharedMemoryInstructionShape::kK >= InstructionShape::kK
                      && SharedMemoryInstructionShape::kK % InstructionShape::kK == 0,
                  "K dimension of load instruction must be a positive multiple of compute instruction K");

    /// Number of compute-instruction K slices covered by one shared-memory load of B
    static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK;

    static_assert(Shape::kK % SharedMemoryInstructionShape::kK == 0,
                  "Warp tile K must be a multiple of the load instruction K");

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = ComplexTransform::kNone;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = ComplexTransform::kNone;

    /// Number of threads participating in warp-level matrix product
    static int const kThreadCount = 32;

    /// Number of partitions along K dimension
    static int const kPartitionsK = PartitionsK_;

public:
    /// Iterates over the A operand in memory (tile sized by the compute instruction)
    using IteratorA = MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kM, Shape::kK>,
                                                          Operand::kA,
                                                          ElementA,
                                                          LayoutA,
                                                          MatrixShape<InstructionShape::kM, InstructionShape::kK>,
                                                          Policy::OpDelta::kRow,
                                                          kThreadCount,
                                                          kPartitionsK>;

    /// Storage for A tile
    using FragmentA = typename IteratorA::Fragment;

    /// Storage for transformed A tile (element type consumed by the instruction)
    using TransformedFragmentA = Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;

    /// Iterates over the B operand in memory. The K extent comes from the shared-memory load
    /// instruction, so each load covers kExpansionFactor compute-instruction K slices.
    using IteratorB =
        MmaTensorOpMultiplicandTileIterator<MatrixShape<Shape::kK, Shape::kN>,
                                            Operand::kB,
                                            ElementB,
                                            LayoutB,
                                            MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>,
                                            Policy::OpDelta::kRow,
                                            kThreadCount,
                                            kPartitionsK>;

    /// Storage for B tile
    using FragmentB = typename IteratorB::Fragment;

    /// Storage for transformed B tile (element type consumed by the instruction)
    using TransformedFragmentB = Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;

    /// Iterates over the C operand in memory
    using IteratorC = MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>,
                                                         ElementC,
                                                         LayoutC,
                                                         typename ArchMmaOperator::Shape,
                                                         typename Policy::OpDelta>;

    /// Storage for C tile
    using FragmentC = typename IteratorC::Fragment;

    /// Number of mma operations performed (ceil-divided warp tile over instruction shape)
    using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
                                      (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>;

public:
    /// Underlying matrix multiply operator (concept: arch::Mma)
    ArchMmaOperator mma;

public:
    //
    // Methods
    //

    /// Ctor
    CUTLASS_DEVICE
    MmaTensorOpComputeBWithF16() {}

    /// Performs a warp-level matrix multiply-accumulate operation: D = A * B + C.
    ///
    /// @param D                   Output accumulator fragment. It is first set to C, then accumulated into.
    /// @param A                   A operand in the instruction's element type (ArchMmaOperator::ElementA).
    ///                            It is indexed as MmaIterations::kRow instruction-sized fragments.
    /// @param B                   B operand in the instruction's element type (ArchMmaOperator::ElementB).
    ///                            It holds kExpansionFactor instruction-sized fragments per column
    ///                            iteration; column n's slices start at fragment index kExpansionFactor * n.
    /// @param C                   Input accumulator fragment.
    /// @param warp_tileB_k_offset Fragment offset added to every B index, selecting which K slice of each
    ///                            column group is consumed. NOTE(review): presumably expected to lie in
    ///                            [0, kExpansionFactor); confirm against callers.
    CUTLASS_DEVICE
    void operator()(FragmentC&                  D,
                    TransformedFragmentA const& A,
                    TransformedFragmentB const& B,
                    FragmentC const&            C,
                    const int                   warp_tileB_k_offset) const
    {

        using MmaOperandA = typename ArchMmaOperator::FragmentA;
        using MmaOperandB = typename ArchMmaOperator::FragmentB;
        using MmaOperandC = typename ArchMmaOperator::FragmentC;

        static_assert(
            TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
            "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of B");

        D = C;

        // View the register fragments as arrays of instruction-sized operand fragments.
        MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
        MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
        MmaOperandC*       ptr_D = reinterpret_cast<MmaOperandC*>(&D);

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
        // Serpentine visitation order maximizing reuse of Rb
        CUTLASS_PRAGMA_UNROLL
        for (int n = 0; n < MmaIterations::kColumn; ++n) {

            CUTLASS_PRAGMA_UNROLL
            for (int m = 0; m < MmaIterations::kRow; ++m) {

                // Odd columns walk the rows in reverse.
                int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);

                // Column n's K slices start at kExpansionFactor * n; the offset picks the slice.
                int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
                if (AccumulatorsInRowMajor) {  // matrix B is reordered
                    mma(ptr_D[n + m_serpentine * MmaIterations::kColumn],
                        ptr_A[m_serpentine],
                        ptr_B[n_offsetB],
                        ptr_D[n + m_serpentine * MmaIterations::kColumn]);
                }
                else {
                    mma(ptr_D[m_serpentine + n * MmaIterations::kRow],
                        ptr_A[m_serpentine],
                        ptr_B[n_offsetB],
                        ptr_D[m_serpentine + n * MmaIterations::kRow]);
                }
            }
        }
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
        // Serpentine visitation order maximizing reuse of Ra
        CUTLASS_PRAGMA_UNROLL
        for (int m = 0; m < MmaIterations::kRow; ++m) {

            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < MmaIterations::kColumn; ++n) {

                // Odd rows walk the columns in reverse.
                int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);

                // Column n_serpentine's K slices start at kExpansionFactor * n_serpentine.
                int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine;
                if (AccumulatorsInRowMajor) {  // matrix B is reordered
                    mma(ptr_D[n_serpentine + m * MmaIterations::kColumn],
                        ptr_A[m],
                        ptr_B[n_serpentine_offsetB],
                        ptr_D[n_serpentine + m * MmaIterations::kColumn]);
                }
                else {
                    mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
                        ptr_A[m],
                        ptr_B[n_serpentine_offsetB],
                        ptr_D[m + n_serpentine * MmaIterations::kRow]);
                }
            }
        }
#else
        // Reached only when __CUDA_ARCH__ is not defined; there is no non-device implementation.
        assert(0);
#endif
    }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free