Home / Class / MmaTensorOpDequantizer Class — pytorch Architecture

MmaTensorOpDequantizer Class — pytorch Architecture

Architecture documentation for the MmaTensorOpDequantizer class (the bfloat16, operand-B partial specialization) in mma_tensorop_dequantizer.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h lines 87–190

/// Warp-level dequantizer for the B operand using bfloat16 scales.
///
/// This partial specialization is selected (via platform::enable_if) only when the
/// arch tag's minimum compute capability is at least 80 AND the architecture MMA
/// operator consumes B in column-major layout. The meaning of the other fixed
/// template arguments (layout::RowMajor, 32) is defined by the primary template,
/// which is not part of this excerpt.
///
/// Usage pattern visible here:
///   1. the constructor positions a per-thread pointer into the scale tensor,
///   2. load() gathers one scale per MMA iteration along N,
///   3. dequantize() multiplies each N-iteration's slice of the B fragment by its scale.
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    bfloat16_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        MmaOperator_::ArchTag::kMinComputeCapability >= 80
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {

public:
    /// Mma Operator
    using MmaOperator = MmaOperator_;

    /// The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    /// Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    /// Ratio of the load instruction's K extent (IteratorB::InstructionShape::kRow) to the
    /// compute instruction's K extent (InstructionShape::kK). dequantize() treats each
    /// N iteration's slice of B as kExpansionFactor arch-level B fragments laid end to end.
    static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;

    /// Type of the scales
    using ElementScale = bfloat16_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    /// Fragment to hold scale data to apply to B before mma.
    /// We need 1 bf16 scale per matrix iteration in the N dimension.
    static constexpr int kColsPerMmaPerThread = 1;
    using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;

    /// Warp mma shape
    using Shape = Shape_;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;

    /// Positions this thread's scale pointer.
    ///
    /// \param smem_scales  Reference to the scales (row-major, per the Layout alias above).
    /// \param warp_idx_n   Index of this warp along N; selects a Shape::kN-wide slice of scales.
    /// \param lane_idx     Lane index within the warp. Lanes are grouped in fours ("quads");
    ///                     all four lanes of a quad share the column offset lane_idx / 4.
    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
        const int warp_offset   = warp_idx_n * Shape::kN;
        const int quad          = lane_idx / 4;
        const int thread_offset = warp_offset + quad;
        pointer_                = smem_scales.data() + thread_offset;
    }

    /// Gathers one scale per MMA iteration along N into scale_frag.
    /// Consecutive iterations read scales that are InstructionShape::kN elements apart.
    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
        }
    }

    /// Scales operand_frag in place: every element belonging to MMA iteration i along N
    /// is multiplied by scale_frag[i], two elements at a time with bf16x2 arithmetic.
    ///
    /// Only device compilations for sm80+ get an implementation; any other
    /// compilation of this body traps via arch::device_breakpoint().
    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        using _MmaOperandB        = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                          == FragmentDequantizedOperand::kElements,
                      "Expanded B fragments must exactly tile the dequantized operand fragment");
        // Loop-invariant: checked once here rather than inside the unrolled loop below.
        static_assert(ExpandedMmaOperandB::kElements % 2 == 0,
                      "Expanded B fragment must hold an even number of elements for bf16x2 math");

        // Read scales as CUDA's native __nv_bfloat16 so they can be broadcast with
        // __bfloat162bfloat162. NOTE: assumes bfloat16_t is layout-compatible with __nv_bfloat16.
        const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);

        // View the operand fragment as MmaIterations::kColumn consecutive expanded B fragments,
        // one per MMA iteration along N.
        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            // Broadcast this iteration's scale into both lanes of a bf16x2 pair.
            __nv_bfloat162  scalex2            = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
            __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);
            CUTLASS_PRAGMA_UNROLL
            for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) {
                operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
            }
        }
#else
        // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
        // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
        // numerous conversion instructions in GEMM main loop.
        arch::device_breakpoint();
#endif
    }

private:
    /// Per-thread base pointer into the scales, set once by the constructor and read by load().
    ElementScale const* pointer_;
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free