Home / Class/ CacheOpA Class — pytorch Architecture

CacheOpA Class — pytorch Architecture

Architecture documentation for `CacheOpA`, the operand-A cache-operation template parameter of the `DqMmaMultistage` class template in dq_mma_multistage.h from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h lines 59–306

template<
    /// Size of the Gemm problem - concept: gemm::GemmShape<>
    typename Shape_,
    /// Iterates over tiles of A operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorA_,
    /// Iterates over tiles of A operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorA_,
    /// Cache operation for operand A
    cutlass::arch::CacheOperation::Kind CacheOpA,
    /// Iterates over tiles of B operand in global memory
    //  (concept: ReadableTileIterator | ForwardTileIterator |
    //  MaskedTileIterator)
    typename IteratorB_,
    /// Iterates over tiles of B operand in shared memory
    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
    typename SmemIteratorB_,
    /// Cache operation for operand B
    cutlass::arch::CacheOperation::Kind CacheOpB,
    /// Data type for the scales
    typename IteratorScale_,
    /// Iterators over scales in shared memory
    typename SmemIteratorScale_,
    /// Data type of accumulator matrix
    typename ElementC_,
    /// Layout of accumulator matrix
    typename LayoutC_,
    /// Policy describing tuning details (concept: MmaPolicy)
    typename Policy_,
    /// Number of stages,
    int Stages,
    /// Converter for the B matrix, applied immediately after the LDS
    typename TransformBAfterLDS_,
    /// Use zfill or predicate for out-of-bound cp.async
    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
    /// Used for partial specialization
    typename Enable = bool>
class DqMmaMultistage: public DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages> {
public:
    ///< Base class
    using Base = DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages>;
    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;
    ///< Iterates over tiles of A operand in global memory
    using IteratorA = IteratorA_;
    ///< Iterates over tiles of B operand in global memory
    using IteratorB = IteratorB_;
    ///< Data type of accumulator matrix
    using ElementC = ElementC_;
    ///< Layout of accumulator matrix
    using LayoutC = LayoutC_;
    ///< Policy describing tuning details
    using Policy = Policy_;

    using IteratorScale = IteratorScale_;
    using ElementScale  = typename IteratorScale::Element;
    using LayoutScale   = typename IteratorScale::Layout;

    using SmemIteratorA     = SmemIteratorA_;
    using SmemIteratorB     = SmemIteratorB_;
    using SmemIteratorScale = SmemIteratorScale_;

    static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
    static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

    using TransformBAfterLDS = TransformBAfterLDS_;

    //
    // Dependent types
    //

    /// Fragment of operand Scale loaded from global memory;
    using FragmentScale = typename IteratorScale::Fragment;

    /// Fragment of accumulator tile
    using FragmentC = typename Policy::Operator::FragmentC;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Minimum architecture is Sm80 to support cp.async
    using ArchTag = arch::Sm80;

    using Dequantizer =
        warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB, ElementScale, LayoutScale, 32>;

    /// Complex transform on A operand
    static ComplexTransform const kTransformA = Operator::kTransformA;

    /// Complex transform on B operand
    static ComplexTransform const kTransformB = Operator::kTransformB;

    /// Compile-time constants describing the pipeline, exposed for introspection.
    struct Detail {

        static_assert(Base::kWarpGemmIterations > 1,
                      "The pipelined structure requires at least two warp-level "
                      "GEMM operations.");

        /// Pipeline depth: the `Stages` template argument.
        static int const kStages = Stages;

        /// cp.async iterations needed to load one full stage of operand A
        /// (the iteration count of IteratorA's thread map).
        static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount;

        /// cp.async iterations needed to load one full stage of operand B
        /// (the iteration count of IteratorB's thread map).
        static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount;

        /// cp.async iterations in one group of operand A:
        /// ceil(AsyncCopyIterationsPerStageA / kWarpGemmIterations).
        static int const kAccessesPerGroupA =
            (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;

        /// cp.async iterations in one group of operand B:
        /// ceil(AsyncCopyIterationsPerStageB / kWarpGemmIterations).
        static int const kAccessesPerGroupB =
            (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
    };

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;
    Dequantizer warp_dequantizer_;

    using ElementB          = typename IteratorB::Element;
    using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;

    static constexpr bool RequiresTileInterleave =
        layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
    static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
                  "Layout K must match threadblockK");

private:
    //
    // Data members
    //

    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

    /// Iterator to write threadblock-scoped tile of scale operand to shared memory
    SmemIteratorScale smem_iterator_scale_;

public:
    /// Builds the per-thread state of the multistage dequantizing mainloop.
    ///
    /// Initializes the shared-memory write iterators for A, B and the scales,
    /// the warp-level dequantizer (bound to the scale buffer in shared storage),
    /// and shifts the warp tile iterators inherited from Base so that this warp
    /// starts at its own (m, n, k) position inside the threadblock tile.
    CUTLASS_DEVICE
    DqMmaMultistage(
        ///< Shared storage needed for internal use by threadblock-scoped GEMM
        typename Base::SharedStorage& shared_storage,
        ///< ID within the threadblock
        int thread_idx,
        ///< ID of warp
        int warp_idx,
        ///< ID of each thread within a warp
        int lane_idx):
        Base(shared_storage, thread_idx, warp_idx, lane_idx),
        // Second argument is this warp's N coordinate (same mapping as warp_n in the body).
        warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
                          (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
                          lane_idx),
        smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
        smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
        smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
    {
        // Warp ids are laid out M-fastest, then N, then K:
        //   warp_idx = warp_m + kM * (warp_n + kN * warp_k)
        int const warps_per_k_slice = Base::WarpCount::kM * Base::WarpCount::kN;

        int const warp_k     = warp_idx / warps_per_k_slice;
        int const warp_in_mn = warp_idx - warp_k * warps_per_k_slice;   // == warp_idx % warps_per_k_slice
        int const warp_n     = warp_in_mn / Base::WarpCount::kM;
        int const warp_m     = warp_in_mn - warp_n * Base::WarpCount::kM;  // == warp_in_mn % kM

        // Offsets are in units of warp-level tiles; the K offset is scaled by
        // the number of warp-level GEMM iterations per warp.
        this->warp_tile_iterator_A_.add_tile_offset({warp_m, Base::kWarpGemmIterations * warp_k});
        this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_k, warp_n});
    }

    /// Issues one group of cp.async copies for operand A, then one for operand B.
    ///
    /// For A, the global iterator is repositioned to access index
    /// `group_start_A * IteratorA::kAccessesPerVector` and the shared-memory
    /// iterator to `group_start_A`. Then up to Detail::kAccessesPerGroupA
    /// iterations are processed, and any iteration index at or beyond
    /// Detail::AsyncCopyIterationsPerStageA is skipped. Each processed iteration
    /// issues IteratorA::kAccessesPerVector copies of kSrcBytesA bytes. The
    /// global iterator advances after every copy, and the shared-memory iterator
    /// advances after every iteration. Operand B follows the same scheme with
    /// its own group start and constants.
    ///
    /// When SharedMemoryClear == kZfill the copies use cp_async_zfill; otherwise
    /// they use cp_async. In both cases the iterator's valid() result is the
    /// predicate.
    CUTLASS_DEVICE
    void
    copy_tiles_and_advance(IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0)
    {
        // Bytes moved by a single cp.async: one access of kElementsPerAccess
        // elements, split into kAccessesPerVector pieces.
        constexpr int kSrcBytesA = sizeof_bits<typename IteratorA::Element>::value
                                   * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8;
        constexpr int kSrcBytesB = sizeof_bits<typename IteratorB::Element>::value
                                   * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8;

        constexpr bool kUseZfill = (SharedMemoryClear == SharedMemoryClearOption::kZfill);

        // ---------------- operand A ----------------
        iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector);
        this->smem_iterator_A_.set_iteration_index(group_start_A);

        CUTLASS_PRAGMA_UNROLL
        for (int access = 0; access < Detail::kAccessesPerGroupA; ++access) {
            // Groups are sized by ceiling division, so the last one may be partial.
            if (group_start_A + access >= Detail::AsyncCopyIterationsPerStageA) {
                continue;
            }

            typename IteratorA::AccessType* const smem_dst =
                reinterpret_cast<typename IteratorA::AccessType*>(this->smem_iterator_A_.get());

            CUTLASS_PRAGMA_UNROLL
            for (int piece = 0; piece < IteratorA::kAccessesPerVector; ++piece) {
                auto src_ptr = iterator_A.get();

                if (kUseZfill) {
                    cutlass::arch::cp_async_zfill<kSrcBytesA, kCacheOpA>(smem_dst + piece, src_ptr, iterator_A.valid());
                }
                else {
                    cutlass::arch::cp_async<kSrcBytesA, kCacheOpA>(smem_dst + piece, src_ptr, iterator_A.valid());
                }

                ++iterator_A;
            }

            ++this->smem_iterator_A_;
        }

        // ---------------- operand B ----------------
        iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector);
        this->smem_iterator_B_.set_iteration_index(group_start_B);

        CUTLASS_PRAGMA_UNROLL
        for (int access = 0; access < Detail::kAccessesPerGroupB; ++access) {
            // Groups are sized by ceiling division, so the last one may be partial.
            if (group_start_B + access >= Detail::AsyncCopyIterationsPerStageB) {
                continue;
            }

            typename IteratorB::AccessType* const smem_dst =
                reinterpret_cast<typename IteratorB::AccessType*>(this->smem_iterator_B_.get());

            CUTLASS_PRAGMA_UNROLL
            for (int piece = 0; piece < IteratorB::kAccessesPerVector; ++piece) {
                auto src_ptr = iterator_B.get();

                if (kUseZfill) {
                    cutlass::arch::cp_async_zfill<kSrcBytesB, kCacheOpB>(smem_dst + piece, src_ptr, iterator_B.valid());
                }
                else {
                    cutlass::arch::cp_async<kSrcBytesB, kCacheOpB>(smem_dst + piece, src_ptr, iterator_B.valid());
                }

                ++iterator_B;
            }

            ++this->smem_iterator_B_;
        }
    }

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free