DqMmaPipelined Class — pytorch Architecture
Architecture documentation for the DqMmaPipelined class in dq_mma_pipelined.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h lines 94–211
class DqMmaPipelined: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2> {
public:
/// Base class. NOTE(review): the literal 2 is presumably the stage count that the
/// static_assert on Base::kStages below checks - confirm against DqMmaBase.
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_; ///< Iterator over the scale operand (source of FragmentScale below)
using ElementScale = typename IteratorScale::Element; ///< Element type exposed by the scale iterator
using LayoutScale = typename IteratorScale::Layout; ///< Layout type exposed by the scale iterator
using SmemIteratorA = SmemIteratorA_; ///< Type of the iterator that writes the A tile to shared memory
using SmemIteratorB = SmemIteratorB_; ///< Type of the iterator that writes the B tile to shared memory
using SmemIteratorScale = SmemIteratorScale_; ///< Type of the iterator that writes the scale tile to shared memory
// NOTE(review): the two transforms below are not used in the visible part of the class;
// from the names they are presumably applied to B after the global load (LDG) and after
// the shared-memory load (LDS) respectively - confirm against operator().
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
/// Warp-level dequantizer for the B operand (Operand::kB), parameterized on the
/// element type of the shared-memory scale fragment and on the scale layout.
/// NOTE(review): the trailing 32 is presumably the number of threads per warp - confirm
/// against warp::MmaTensorOpDequantizer.
using Dequantizer = warp::MmaTensorOpDequantizer<Operator,
typename Base::WarpGemm,
Operand::kB,
typename SmemIteratorScale::Fragment::Element,
LayoutScale,
32>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
private:
/// Warp-level fragment types for the A and B operands, taken from the warp-level Operator
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
/// Dequantizer instance; the constructor builds it from shared_storage.operand_scale
Dequantizer warp_dequantizer_;
/// Element type of the B operand as seen by the global-memory iterator
using ElementB = typename IteratorB::Element;
/// Layout details selected for this B element type and architecture tag
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
/// True when the layout chosen for B is a column-major tile-interleaved layout
static constexpr bool RequiresTileInterleave =
layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
// Compile-time guard: when B uses a tile-interleaved layout, the threadblock K extent
// (Shape::kK) must equal the ThreadblockK that the layout details declare.
// Layouts that are not tile-interleaved pass unconditionally.
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct the threadblock-scoped pipelined MMA from the shared storage and the
/// calling thread's indices.
///
/// The initializer list forwards everything to the base class, builds the warp-level
/// dequantizer and the three shared-memory write iterators (A, B, scale); the body then
/// moves the inherited warp tile iterators for A and B to this warp's tile.
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
                   shared_storage,  ///< Shared storage needed for internal use by threadblock-scoped GEMM
               int thread_idx,      ///< ID within the threadblock
               int warp_idx,        ///< ID of warp
               int lane_idx         ///< ID of each thread within a warp
               ):
    Base(shared_storage, thread_idx, warp_idx, lane_idx),
    // The second argument is this warp's N coordinate; it is recomputed in the body
    // because locals are not available inside the initializer list.
    warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
                      (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM,
                      lane_idx),
    smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
    smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
    smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
{
    // Decompose the linear warp index into (m, n, k) coordinates inside the threadblock
    // tile. Warps are numbered with M varying fastest, then N, then K.
    int const warps_per_mn_slice = Base::WarpCount::kM * Base::WarpCount::kN;

    int const k_coord = warp_idx / warps_per_mn_slice;
    int const mn_linear = warp_idx % warps_per_mn_slice;

    int const m_coord = mn_linear % Base::WarpCount::kM;
    int const n_coord = mn_linear / Base::WarpCount::kM;

    // Shift the warp-level iterators to this warp's tile; offsets are expressed in
    // units of warp-level tiles.
    this->warp_tile_iterator_A_.add_tile_offset({m_coord, Base::kWarpGemmIterations * k_coord});
    this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * k_coord, n_coord});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free