Home / Class / DqMmaBase Class — pytorch Architecture

DqMmaBase Class — pytorch Architecture

Architecture documentation for the DqMmaBase class in dq_mma_base.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/cutlass_extensions/gemm/threadblock/dq_mma_base.h lines 91–228

/// Threadblock-scoped base class for dequantizing MMA pipelines: owns the
/// shared-memory staging buffers for the A/B operands (plus per-column scales)
/// and the warp-level iterators that read tiles back out of shared memory.
class DqMmaBase {
public:
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;

    ///< Policy describing tuning details
    using Policy = Policy_;

    ///< Type of the scale to be loaded
    using ElementScale = ElementScale_;

    //
    // Dependent types
    //

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Shape describing the overall GEMM computed from shared memory
    /// by each warp.
    using WarpGemm = typename Policy::Operator::Shape;

    /// Shape describing the number of warps filling the CTA
    /// (threadblock extent divided by per-warp extent in each dimension).
    using WarpCount = GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN, Shape::kK / WarpGemm::kK>;

    /// Number of warp-level GEMM operations needed to cover one warp tile
    /// along K (warp K extent / MMA instruction K extent).
    static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK);

    /// Number of MMA K-iterations covered by a single warp-level B-operand
    /// load: ratio of the B iterator's instruction-shape rows to the MMA
    /// instruction's K extent. Used below to step the B iterator less often
    /// than A when one B fragment feeds several MMAs.
    static constexpr int kNumKIterationsPerWarpBLoad =
        Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;

    // B loads must tile the K iteration space evenly, otherwise the
    // per-B-load iteration count below would truncate.
    static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad),
                  "kWarpGemmIterations must be divisible by kNumKIterationsPerWarpBLoad");
    static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad;

    /// Number of stages
    static int const kStages = Stages;

    /// Tensor reference to the A operand
    using TensorRefA = TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;

    /// Tensor reference to the B operand
    using TensorRefB = TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;

    //
    // Nested structs
    //

    /// Shared storage object needed by threadblock-scoped GEMM
    class SharedStorage {
    public:
        //
        // Type definitions
        //

        /// Shape of the A matrix operand in shared memory: kStages K-tiles
        /// laid out side by side along the column dimension, plus padding.
        using ShapeA =
            MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

        /// Shape of the B matrix operand in shared memory: kStages K-tiles
        /// stacked along the row dimension, plus padding.
        using ShapeB =
            MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;

    public:
        //
        // Data members
        //

        /// Buffer for A operand
        AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

        /// Buffer for B operand
        AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

        /// Buffer to hold scales for threadblock (one per output column N)
        AlignedBuffer<ElementScale, Shape::kN> operand_scale;

    public:
        //
        // Methods
        //

        /// Returns a layout object for the A matrix
        // NOTE(fix): was CUTLASS_DEVICE, inconsistent with LayoutB() below;
        // operand_A_ref() is CUTLASS_HOST_DEVICE and calls this, so it must
        // be callable from host code as well.
        CUTLASS_HOST_DEVICE
        static typename Operator::LayoutA LayoutA()
        {
            return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
        }

        /// Returns a layout object for the B matrix
        CUTLASS_HOST_DEVICE
        static typename Operator::LayoutB LayoutB()
        {
            return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
        }

        /// Returns a TensorRef to the A operand
        CUTLASS_HOST_DEVICE
        TensorRefA operand_A_ref()
        {
            return TensorRefA{operand_A.data(), LayoutA()};
        }

        /// Returns a TensorRef to the B operand
        CUTLASS_HOST_DEVICE
        TensorRefB operand_B_ref()
        {
            return TensorRefB{operand_B.data(), LayoutB()};
        }
    };

protected:
    //
    // Data members
    //

    /// Iterator to load a warp-scoped tile of A operand from shared memory
    typename Operator::IteratorA warp_tile_iterator_A_;

    /// Iterator to load a warp-scoped tile of B operand from shared memory
    typename Operator::IteratorB warp_tile_iterator_B_;

public:
    /// Construct from tensor references.
    /// thread_idx and warp_idx are accepted for interface parity with other
    /// CUTLASS Mma base classes but are unused here — presumably derived
    /// pipeline classes position the iterators per-warp themselves.
    CUTLASS_DEVICE
    DqMmaBase(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        SharedStorage& shared_storage,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx):
        warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
        warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx)
    {
    }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free