SharedStorage Class — pytorch Architecture

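Architecture documentation for the SharedStorage class in dq_mma_base.h from the pytorch codebase. The class declares the shared-memory buffers for the A and B operands and for the threadblock's scales, and provides layout and TensorRef accessors for the two operands.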

Source Code

aten/src/ATen/native/cuda/cutlass_extensions/gemm/threadblock/dq_mma_base.h lines 139–199

    class SharedStorage {
    public:
        //
        // Type definitions
        //

        /// Shape of the A matrix operand in shared memory
        using ShapeA =
            MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow, Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;

        /// Shape of the B matrix operand in shared memory
        using ShapeB =
            MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow, Shape::kN + Policy::SmemPaddingB::kColumn>;

    public:
        //
        // Data members
        //

        /// Buffer for A operand
        AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;

        /// Buffer for B operand
        AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;

        /// Buffer to hold scales for threadblock
        AlignedBuffer<ElementScale, Shape::kN> operand_scale;

    public:
        //
        // Methods
        //

        /// Returns a layout object for the A matrix
        CUTLASS_DEVICE
        static typename Operator::LayoutA LayoutA()
        {
            return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
        }

        /// Returns a layout object for the B matrix
        CUTLASS_HOST_DEVICE
        static typename Operator::LayoutB LayoutB()
        {
            return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
        }

        /// Returns a TensorRef to the A operand
        CUTLASS_HOST_DEVICE
        TensorRefA operand_A_ref()
        {
            return TensorRefA{operand_A.data(), LayoutA()};
        }

        /// Returns a TensorRef to the B operand
        CUTLASS_HOST_DEVICE
        TensorRefB operand_B_ref()
        {
            return TensorRefB{operand_B.data(), LayoutB()};
        }
    };
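
The buffer sizes follow directly from the ShapeA and ShapeB definitions above. The sketch below works through that arithmetic at compile time. It is only an illustration under stated assumptions: MatrixShape here is a minimal stand-in for the CUTLASS template, and the tile shape (128x128x64), the stage count (2), the zero padding, and the element types are hypothetical values chosen for the example, not taken from dq_mma_base.h.

```cpp
#include <cstddef>
#include <cstdint>

// Minimal stand-in for cutlass::MatrixShape: compile-time row/column extents.
template <int Row, int Column>
struct MatrixShape {
    static constexpr int kRow = Row;
    static constexpr int kColumn = Column;
    static constexpr int kCount = Row * Column;
};

// Hypothetical threadblock tile, padding, and stage count (assumptions).
struct TileShape {
    static constexpr int kM = 128;
    static constexpr int kN = 128;
    static constexpr int kK = 64;
};
using SmemPaddingA = MatrixShape<0, 0>;
using SmemPaddingB = MatrixShape<0, 0>;
constexpr int kStages = 2;

// Same formulas as the ShapeA / ShapeB aliases in the listing:
// A is (M) x (K * stages), B is (K * stages) x (N), each plus padding.
using ShapeA = MatrixShape<TileShape::kM + SmemPaddingA::kRow,
                           TileShape::kK * kStages + SmemPaddingA::kColumn>;
using ShapeB = MatrixShape<TileShape::kK * kStages + SmemPaddingB::kRow,
                           TileShape::kN + SmemPaddingB::kColumn>;

// Hypothetical element types: 16-bit A and scales, 8-bit B.
using ElementA = std::uint16_t;      // stand-in for a 16-bit float type
using ElementB = std::uint8_t;       // stand-in for an 8-bit quantized type
using ElementScale = std::uint16_t;  // stand-in for a 16-bit float type

constexpr std::size_t kBytesA = ShapeA::kCount * sizeof(ElementA);
constexpr std::size_t kBytesB = ShapeB::kCount * sizeof(ElementB);
constexpr std::size_t kBytesScale = TileShape::kN * sizeof(ElementScale);

// Sum of the three buffers, ignoring any alignment padding between members.
constexpr std::size_t kBytesTotal = kBytesA + kBytesB + kBytesScale;

static_assert(ShapeA::kRow == 128 && ShapeA::kColumn == 128, "A is M x (K * stages)");
static_assert(ShapeB::kRow == 128 && ShapeB::kColumn == 128, "B is (K * stages) x N");
static_assert(kBytesTotal == 49408, "32768 (A) + 16384 (B) + 256 (scales)");
```

With these assumed values, the K extent of both operand buffers is multiplied by the stage count, so each buffer holds one tile per pipeline stage, while the scale buffer holds a single row of kN scales regardless of the number of stages.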
