MmaTensorOpDequantizer Class — PyTorch Architecture
Architecture documentation for the MmaTensorOpDequantizer class (the bfloat16, sm80+ partial specialization shown below) in mma_tensorop_dequantizer.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cuda/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h lines 87–190
template<
    /// Underlying matrix multiply operator (concept: MmaTensorOp)
    typename MmaOperator_,
    /// Shape of the warp level matrix multiply (concept: GemmShape)
    typename Shape_>
class MmaTensorOpDequantizer<
    MmaOperator_,
    Shape_,
    Operand::kB,
    bfloat16_t,
    layout::RowMajor,
    32,
    typename platform::enable_if<
        MmaOperator_::ArchTag::kMinComputeCapability >= 80
        && platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB, layout::ColumnMajor>::value>::type> {
public:
    /// Mma Operator
    using MmaOperator = MmaOperator_;

    // The architecture specific mma operator being used
    using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;

    // Mma Instruction Shape
    using InstructionShape = typename ArchMmaOperator::Shape;

    // This is the ratio of the load instruction vs the compute instruction.
    static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;

    /// Type of the scales
    using ElementScale = bfloat16_t;

    /// Fragment to hold B data before Mma
    using FragmentDequantizedOperand = Array<ElementScale, MmaOperator::FragmentB::kElements>;

    // Fragment to hold scale data to apply to B before mma
    // We need 1 fp16 per matrix iteration in the N dimension
    static constexpr int kColsPerMmaPerThread = 1;
    using FragmentScale = Array<ElementScale, kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;

    /// Warp mma shape
    using Shape = Shape_;

    /// Layout of the scales in shared memory
    using Layout = layout::RowMajor;

    /// TensorRef type for loading element from a tensor
    using TensorRef = TensorRef<ElementScale, Layout>;
    CUTLASS_DEVICE
    MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
    {
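        // In the Ampere tensor-op B-fragment layout (e.g. mma.sync m16n8k16), the
        // four lanes of a quad (lane_idx / 4) all hold elements of the same N
        // column, so each thread offsets its scale pointer to that column within
        // this warp's Shape::kN-wide tile.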
        const int warp_offset = warp_idx_n * Shape::kN;
        const int quad = lane_idx / 4;
        const int thread_offset = warp_offset + quad;
        pointer_ = smem_scales.data() + thread_offset;
    }
    CUTLASS_DEVICE
    void load(FragmentScale& scale_frag)
    {
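        // One scale is read per mma iteration along N; consecutive iterations are
        // spaced InstructionShape::kN columns apart, and the per-thread column
        // offset chosen in the constructor selects the right element of each tile.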
        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
        }
    }
    CUTLASS_DEVICE
    void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
    {
//#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
        using _MmaOperandB = typename ArchMmaOperator::FragmentB;
        using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
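        // The B fragment is re-viewed as one expanded sub-fragment per mma column
        // iteration; kExpansionFactor accounts for the load instruction spanning
        // more of K than a single compute instruction, and every element of a
        // sub-fragment belongs to the same N column, hence shares one scale.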
        static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn
                          == FragmentDequantizedOperand::kElements,
                      "");

        const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);
        ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);

        CUTLASS_PRAGMA_UNROLL
        for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) {
            static_assert(ExpandedMmaOperandB::kElements % 2 == 0, "");
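            // Broadcast this column's scale into both halves of a bf16x2 so each
            // __hmul2 below rescales two operand values in one instruction.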
            __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]);
            __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]);

            CUTLASS_PRAGMA_UNROLL
            for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) {
                operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2);
            }
        }
#else
        // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should
        // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid
        // numerous conversion instructions in GEMM main loop.
        arch::device_breakpoint();
#endif
    }

private:
    ElementScale const* pointer_;
};
Source
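For context, the fast path in dequantize() reduces to a packed bfloat16 multiply: the per-column scale is broadcast into both halves of a __nv_bfloat162 and applied to pairs of operand values with __hmul2. The standalone CUDA kernel below is a minimal sketch of that same pattern outside of CUTLASS; the kernel and its parameters are illustrative only (they are not part of the PyTorch source), and it assumes an even column count and an sm_80 or newer device.

#include <cuda_bf16.h>

// Illustrative kernel (not from the PyTorch source): rescale each row of a
// row-major [rows x cols] bf16 matrix by a per-row scale, two elements per
// __hmul2, mirroring the packed multiply used in dequantize().
__global__ void scale_rows_bf16(__nv_bfloat16* data,
                                const __nv_bfloat16* row_scales,
                                int rows,
                                int cols)
{
    const int row = blockIdx.x;
    if (row >= rows) {
        return;
    }

    // Broadcast this row's scale into both halves of a packed bf16x2.
    const __nv_bfloat162 scale_x2 = __bfloat162bfloat162(row_scales[row]);
    __nv_bfloat162* row_x2 = reinterpret_cast<__nv_bfloat162*>(data + row * cols);

    // Each thread strides over the row, scaling two adjacent elements at a time.
    for (int pair = threadIdx.x; pair < cols / 2; pair += blockDim.x) {
        row_x2[pair] = __hmul2(row_x2[pair], scale_x2);
    }
}

Launched as, for example, scale_rows_bf16<<<rows, 128>>>(data, row_scales, rows, cols), this performs the same broadcast-and-multiply step that the dequantizer applies to each expanded B sub-fragment.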