kM Class — pytorch Architecture

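Architecture documentation for the kM class in mma_from_smem.h from the PyTorch codebase. The source listed below is the AccumulatorSharedStorage class template, which holds an accumulator tile in shared memory. kM is the row count of its Shape parameter, and the stored tile is padded to (Shape::kM + Padding::kRow) rows by (Shape::kN + Padding::kColumn) columns.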

Source Code

aten/src/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h lines 79–125

template <
    typename Shape_,
    typename Element_,
    typename Layout_,
    typename Padding_>
class AccumulatorSharedStorage {
 public:
  //
  // Type definitions
  //
  using Shape = Shape_;
  using Element = Element_;
  using Layout = Layout_;
  using Padding = Padding_;

  /// Tensor reference to the accumulator
  using TensorRefAccum = cutlass::TensorRef<Element, Layout>;

  /// Shape of the accumulator matrix in shared memory
  using ShapeAccum = cutlass::
      MatrixShape<Shape::kM + Padding::kRow, Shape::kN + Padding::kColumn>;

 public:
  //
  // Data members
  //

  /// Buffer for accumulator
  cutlass::AlignedBuffer<Element, ShapeAccum::kCount> accum;

 public:
  //
  // Methods
  //

  /// Returns a layout object for the Accum matrix
  CUTLASS_DEVICE
  static Layout LayoutAccum() {
    return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn});
  }

  /// Returns a TensorRef to the Accumulator
  CUTLASS_HOST_DEVICE
  TensorRefAccum accum_ref() {
    return TensorRefAccum{accum.data(), LayoutAccum()};
  }
};
