Home / Class / void Class — pytorch Architecture

void Class — pytorch Architecture

Architecture documentation for the `KernelRunner<true, dummy>` struct specialization (whose static `run_kernel` method returns `void`) in fpA_intB_gemm.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/cuda/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h lines 307–474

    /// KernelRunner specialization selected when its first template argument is `true`.
    /// It holds the device-side body of the GEMM kernel as a single static method.
    template<typename dummy>
    struct KernelRunner<true, dummy> {
        /// Computes one threadblock's output tile.
        ///
        /// Steps, in order:
        ///   1. Locate this CTA's tile via the threadblock swizzle; return early if the
        ///      tile lies outside the tiled grid in M or N.
        ///   2. Build iterators over A, B and the scale operand for this CTA's K slice and
        ///      run the threadblock-scoped multiply-add into `accumulators`.
        ///   3. Run the epilogue. When serial split-K is active (`kSplitKSerial` and more
        ///      than one K partition), a per-tile semaphore orders the partitions: each
        ///      one waits for its own index, and all but partition 0 read their source
        ///      from the D tensor.
        ///
        /// @param params          Kernel arguments (problem size, tensor refs, iterator
        ///                        params, semaphore pointer, output-op params).
        /// @param shared_storage  Storage handed to the main loop (`main_loop`) and to the
        ///                        epilogue (`epilogue`).
        CUTLASS_DEVICE
        static void run_kernel(Params const& params, SharedStorage& shared_storage)
        {
            using LayoutB = typename Mma::IteratorB::Layout;
            static_assert((platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1)
                              || (platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1),
                          "B must be row major/col major OR col major interleaved.");

            // Map this CTA onto its (m, n, k) tile coordinate.
            ThreadblockSwizzle swizzle;

            cutlass::gemm::GemmCoord tile_coord = swizzle.get_tile_offset(params.swizzle_log_tile);

            // A CTA whose tile falls outside the tiled grid in M or N has no work to do.
            const bool tile_outside_grid = tile_coord.m() >= params.grid_tiled_shape.m()
                                           || tile_coord.n() >= params.grid_tiled_shape.n();
            if (tile_outside_grid) {
                return;
            }

            // First K coordinate covered by this CTA's K partition.
            const int slice_k_begin = tile_coord.k() * params.gemm_k_size;

            // Logical starting coordinates of this threadblock in A, B and the scale operand.
            // B's coordinates are expressed in its interleaved shape: rows scaled up by
            // kInterleave, columns scaled down by it.
            cutlass::MatrixCoord origin_A{
                tile_coord.m() * Mma::Shape::kM,
                slice_k_begin,
            };

            cutlass::MatrixCoord origin_B{slice_k_begin * kInterleave,
                                          tile_coord.n() * Mma::Shape::kN / kInterleave};

            cutlass::MatrixCoord origin_scale{0, tile_coord.n() * Mma::Shape::kN};

            // End of this CTA's K range: the partition's end, clamped to the full problem K.
            const int slice_k_end = min(params.problem_size.k(), (tile_coord.k() + 1) * params.gemm_k_size);

            // Number of main-loop iterations: ceil((slice_k_end - slice_k_begin) / kK).
            const int k_tile_count = (slice_k_end - slice_k_begin + Mma::Shape::kK - 1) / Mma::Shape::kK;

            // Linear thread index within the threadblock.
            const int tid = threadIdx.x;

            // Operand iterators, each starting at this threadblock's origin.
            typename Mma::IteratorA iterator_A(params.params_A,
                                               params.ref_A.data(),
                                               {params.problem_size.m(), slice_k_end},
                                               tid,
                                               origin_A,
                                               params.gather_A_indices);

            typename Mma::IteratorB iterator_B(params.params_B,
                                               params.ref_B.data(),
                                               {slice_k_end * kInterleave, params.problem_size.n() / kInterleave},
                                               tid,
                                               origin_B,
                                               params.gather_B_indices);

            typename Mma::IteratorScale iterator_scale(params.params_scale,
                                                       params.ref_scale.data(),
                                                       {1, params.problem_size.n()},
                                                       tid,
                                                       origin_scale);

            // Take the warp index from lane 0 via shuffle so that code depending on it
            // is compiled as warp-uniform.
            int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
            int lane_idx = threadIdx.x % 32;

            //
            // Main loop
            //

            // Threadblock-scoped multiply-add object, backed by the main-loop shared storage.
            Mma mma(shared_storage.main_loop, tid, warp_idx, lane_idx);

            typename Mma::FragmentC accumulators;

            accumulators.clear();

            // Skipped only when serial split-K is enabled and this partition has no K tiles.
            if (!kSplitKSerial || k_tile_count > 0) {
                // The accumulators serve as both the source and the destination fragment.
                mma(k_tile_count, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
            }

            //
            // Epilogue
            //

            EpilogueOutputOp output_op(params.output_op);

            // Re-derive the tile coordinate for the epilogue stage.
            tile_coord = swizzle.get_tile_offset(params.swizzle_log_tile);

            // Output-tile origin in the M x N plane (assumes an identity swizzle).
            MatrixCoord output_origin(tile_coord.m() * Mma::Shape::kM, tile_coord.n() * Mma::Shape::kN);

            // One semaphore slot per (m, n) output tile.
            int tile_linear_idx = tile_coord.m() + tile_coord.n() * params.grid_tiled_shape.m();

            Semaphore semaphore(params.semaphore + tile_linear_idx, tid);

            // True when several K partitions reduce serially into the same output tile.
            const bool serial_split_k_active = kSplitKSerial && params.grid_tiled_shape.k() > 1;

            if (serial_split_k_active) {
                // Start fetching the lock now, without blocking.
                semaphore.fetch();

                // Tell the output operator which step of the serial reduction this is.
                output_op.set_k_partition(tile_coord.k(), params.grid_tiled_shape.k());
            }

            // Source-tensor (C) tile iterator.
            typename Epilogue::OutputTileIterator iterator_C(params.params_C,
                                                             params.ref_C.data(),
                                                             params.problem_size.mn(),
                                                             tid,
                                                             output_origin,
                                                             params.scatter_D_indices);

            // Destination-tensor (D) tile iterator.
            typename Epilogue::OutputTileIterator iterator_D(params.params_D,
                                                             params.ref_D.data(),
                                                             params.problem_size.mn(),
                                                             tid,
                                                             output_origin,
                                                             params.scatter_D_indices);

            Epilogue epilogue(shared_storage.epilogue, tid, warp_idx, lane_idx);

            if (serial_split_k_active) {
                // Every partition after the first reads its source from the D tensor.
                if (tile_coord.k() != 0) {
                    iterator_C = iterator_D;
                }

                // Block until it is this partition's turn; the iterator construction above
                // may already have hidden some of this latency.
                semaphore.wait(tile_coord.k());
            }

            // Apply the output operator and write the destination tile.
            epilogue(output_op, iterator_D, accumulators, iterator_C);

            //
            // Release the semaphore
            //

            if (serial_split_k_active) {
                // The last partition resets the lock to 0 for subsequent grids; every
                // other partition passes the turn on by publishing the next index.
                const bool is_final_partition = params.grid_tiled_shape.k() == tile_coord.k() + 1;
                const int  next_lock          = is_final_partition ? 0 : tile_coord.k() + 1;

                semaphore.release(next_lock);
            }
        }
    };

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free