ScatterD Class — pytorch Architecture
Architecture documentation for ScatterD, a bool template parameter of the PredicatedTileIteratorPrefetch class template in epilogue_predicated_tile_iterator.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h lines 70–251
template <
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
typename Element_, ///< Element data type
bool ScatterD = false, ///< Scatter D operand or not
bool UseCUDAStore = false>
class PredicatedTileIteratorPrefetch {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = Element_;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kThreads = ThreadMap::kThreads;
static int const kIterations = ThreadMap::Count::kTile;
static_assert(
ThreadMap::Iterations::kRow > 0,
"ThreadMap::Iterations::kRow must be > 0");
static_assert(
ThreadMap::Iterations::kGroup > 0,
"ThreadMap::Iterations::kGroup must be > 0");
static_assert(
ThreadMap::Iterations::kCluster > 0,
"ThreadMap::Iterations::kCluster must be > 0");
static_assert(
ThreadMap::Iterations::kColumn > 0,
"ThreadMap::Iterations::kColumn must be > 0");
/// Fragment object
using Fragment = Array<
Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster *
ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
//
// Parameters struct
//
/// Iterator parameters: a thin wrapper that derives from the non-template
/// PredicatedTileIteratorParams and adds layout-based construction.
struct Params : PredicatedTileIteratorParams {
  using Base = PredicatedTileIteratorParams;

  /// Default constructor; the base parameters are default-constructed.
  CUTLASS_HOST_DEVICE Params() {}

  /// Wraps an existing base parameter object by copying it.
  CUTLASS_HOST_DEVICE Params(Base const& base) : Base(base) {}

  /// Builds the parameters from a layout.
  ///
  /// The first base argument is layout.stride(0) scaled by
  /// sizeof(AccessType) / kElementsPerAccess (presumably the row stride
  /// expressed in bytes -- confirm against PredicatedTileIteratorParams).
  /// The multiplication is performed before the division, as written.
  /// The second argument is the thread-map descriptor for ThreadMap.
  CUTLASS_HOST_DEVICE Params(Layout const& layout)
      : Base(
            layout.stride(0) * static_cast<int>(sizeof(AccessType)) /
                kElementsPerAccess,
            make_OutputTileThreadMapDesc<ThreadMap>()) {}
};
/// Mask object: holds one boolean predicate per column iteration of the
/// thread map.
struct Mask {
  /// Number of predicates, equal to ThreadMap::Iterations::kColumn.
  static int const kCount = ThreadMap::Iterations::kColumn;

  /// Predicate state; entry i guards the accesses of column iteration i.
  bool predicates[kCount];

  //
  // Mask
  //

  /// Constructs a mask with every predicate enabled.
  CUTLASS_HOST_DEVICE
  Mask() {
    enable();
  }

  /// Efficiently disables all accesses guarded by the mask.
  CUTLASS_HOST_DEVICE void clear() {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      predicates[i] = false;
    }
  }

  /// Enables all accesses guarded by the mask.
  ///
  /// Declared CUTLASS_HOST_DEVICE rather than device-only because the
  /// CUTLASS_HOST_DEVICE constructor of this struct calls it; a host+device
  /// function must not call a device-only function.
  CUTLASS_HOST_DEVICE void enable() {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kCount; ++i) {
      predicates[i] = true;
    }
  }
};
private:
//
// Data members
//
/// Parameters structure containing reference and precomputed state.
PredicatedTileIteratorParams params_;
/// Byte-level pointer
uint8_t* byte_pointer_;
/// Array of boolean values to contain steady-state predicates
Mask mask_;
/// Extent of the matrix tile in rows
Index extent_row_;
/// Extent of the matrix tile in columns
Index extent_column_;
/// A thread's starting row position (assuming steady-state predicates have
/// been computed)
Index thread_start_row_;
/// A thread's starting column
Index thread_start_column_;
/// Internal state counter
int state_[3];
/// Scatter indices
int const* indices_;
//
// Static asserts about internal strides
//
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
static_assert(
sizeof(PredicatedTileIteratorParams::stride) == 8,
"Expected 64b strides");
private:
//
// Methods
//
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
PredicatedTileIteratorPrefetch(
PredicatedTileIteratorParams const& params,
Element* pointer,
TensorCoord extent,
int thread_idx,
TensorCoord threadblock_offset = TensorCoord(),
int const* indices = nullptr)
: params_(params), indices_(indices) {
TensorCoord thread_offset =
ThreadMap::initial_offset(thread_idx) + threadblock_offset;
extent_row_ = extent.row();
extent_column_ = extent.column();
thread_start_row_ = thread_offset.row();
thread_start_column_ = thread_offset.column();
// Initialize predicates
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
mask_.predicates[c] =
((thread_offset.column() + ThreadMap::Delta::kColumn * c) <
extent.column());
}
// Null pointer performs no accesses
if (!pointer) {
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free