Home / Class / DqMmaPipelined Class — PyTorch Architecture

DqMmaPipelined Class — PyTorch Architecture

Architecture documentation for the DqMmaPipelined class in dq_mma_pipelined.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h lines 94–211

class DqMmaPipelined: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2> {
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2>;

    using Shape     = Shape_;      ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using IteratorA = IteratorA_;  ///< Iterates over tiles of A operand in global memory
    using IteratorB = IteratorB_;  ///< Iterates over tiles of B operand in global memory
    using ElementC  = ElementC_;   ///< Data type of accumulator matrix
    using LayoutC   = LayoutC_;    ///< Layout of accumulator matrix
    using Policy    = Policy_;     ///< Policy describing tuning details

    using IteratorScale = IteratorScale_;
    using ElementScale  = typename IteratorScale::Element;
    using LayoutScale   = typename IteratorScale::Layout;

    using SmemIteratorA     = SmemIteratorA_;
    using SmemIteratorB     = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    using TransformBAfterLDG = TransformBAfterLDG_;
    using TransformBAfterLDS = TransformBAfterLDS_;

    //
    // Dependent types
    //

    /// Fragment of operand A loaded from global memory
    using FragmentA = typename IteratorA::Fragment;

    /// Fragment of operand B loaded from global memory
    using FragmentB = typename IteratorB::Fragment;

    /// Fragment of operand Scale loaded from global memory;
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Obtain the arch tag from the warp-level operator
    using ArchTag = typename Policy::Operator::ArchTag;

    using Dequantizer = warp::MmaTensorOpDequantizer<Operator,
                                                     typename Base::WarpGemm,
                                                     Operand::kB,
                                                     typename SmemIteratorScale::Fragment::Element,
                                                     LayoutScale,
                                                     32>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    // statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
    static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;
    Dequantizer warp_dequantizer_;

    using ElementB          = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave =
        layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
                  "Layout K must match threadblockK");

protected:
    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Construct from tensor references
    /// Builds the per-thread state of the threadblock-scoped, dequantizing mainloop.
    ///
    /// The constructor does four things:
    ///   - binds the write iterators for the A, B, and scale tiles to `shared_storage`;
    ///   - builds the warp-level dequantizer over the scale buffer in `shared_storage`;
    ///   - passes this warp's N coordinate to that dequantizer;
    ///   - moves the inherited warp tile iterators for A and B to this warp's position.
    ///
    /// @param shared_storage  Shared storage used internally by the threadblock-scoped GEMM.
    /// @param thread_idx      Index of the calling thread within the threadblock.
    /// @param warp_idx        Index of the calling warp within the threadblock.
    /// @param lane_idx        Index of the calling thread within its warp.
    CUTLASS_DEVICE
    DqMmaPipelined(typename Base::SharedStorage& shared_storage,
                   int                           thread_idx,
                   int                           warp_idx,
                   int                           lane_idx):
        Base(shared_storage, thread_idx, warp_idx, lane_idx),
        // The second argument is this warp's N coordinate. It uses the same formula
        // as `n_coord` in the constructor body below.
        warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
                          (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
                          lane_idx),
        smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
        smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
        // The scale tile is a single row of Shape::kN elements.
        smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
    {
        // Split the linear warp id into (m, n, k) coordinates. M varies fastest,
        // then N, and each block of kM * kN consecutive warps shares one K slice.
        int const warps_per_k_slice = Base::WarpCount::kM * Base::WarpCount::kN;

        int const mn_linear = warp_idx % warps_per_k_slice;
        int const k_coord   = warp_idx / warps_per_k_slice;
        int const m_coord   = mn_linear % Base::WarpCount::kM;
        int const n_coord   = mn_linear / Base::WarpCount::kM;

        // Shift the inherited warp tile iterators. Offsets are in units of
        // warp-level tiles. Along K, each warp skips the tiles consumed by the
        // warps in the earlier K slices.
        this->warp_tile_iterator_A_.add_tile_offset({m_coord, Base::kWarpGemmIterations * k_coord});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * k_coord, n_coord});
    }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free