PredicatedTileIteratorResidualLast (AffineRankN<2> Specialization) — PyTorch Architecture
Architecture documentation for the layout::AffineRankN<2> partial specialization of the PredicatedTileIteratorResidualLast class template in predicated_tile_iterator_residual_last.h from the PyTorch codebase. Shape_ is the first of its template parameters (the threadblock tile shape), alongside Element_, AdvanceRank, ThreadMap_, and AccessSize.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h lines 926–1199
template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_, int AccessSize>
class PredicatedTileIteratorResidualLast<
Shape_,
Element_,
layout::AffineRankN<2>,
AdvanceRank,
ThreadMap_,
AccessSize,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::AffineRankN<2>;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
/// Type used for internal memory accesses
using AccessType = AlignedArray<
Element,
AccessSize,
(AccessSize * sizeof_bits<Element>::value / 8)>;
/// Underlying iterator to compute the addresses
using TileAccessIterator = PredicatedTileAccessIteratorResidualLast<
Shape,
Element,
Layout,
kAdvanceRank,
ThreadMap,
AccessType>;
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
/// Fragment object to be loaded or stored
using Fragment = cutlass::Array<
Element,
ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
/// Predicate vector stores mask to guard accesses
using Mask = typename TileAccessIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
public:
friend PredicatedTileIteratorResidualLast;
private:
/// Parameters object
typename TileAccessIterator::Params params_;
public:
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) : params_(layout) {}
CUTLASS_HOST_DEVICE
Params() {}
};
private:
/// Internal pointer type permits fast address arithmetic
using BytePointer = char*;
private:
//
// Data members
//
/// Data member to the tile access iterator
TileAccessIterator address_iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
/// Precomputed parameters object
Params const& params,
/// Pointer to start of tensor
Pointer pointer,
/// Extent of tensor
TensorCoord extent,
/// ID of each participating thread
int thread_id,
/// Initial offset of threadblock
TensorCoord const& threadblock_offset,
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: address_iterator_(
params.params_,
pointer,
extent,
thread_id,
threadblock_offset) {}
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileIteratorResidualLast(
params,
pointer,
extent,
thread_id,
make_Coord(0, 0)) {}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
address_iterator_.add_pointer_offset(pointer_offset);
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast& operator++() {
if (kAdvanceRank)
address_iterator_.add_tile_offset(make_Coord(0, 1));
else
address_iterator_.add_tile_offset(make_Coord(1, 0));
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast operator++(int) {
PredicatedTileIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) {
address_iterator_.clear_mask(enable);
}
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) {
address_iterator_.set_residual_tile(enable);
}
/// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() {
address_iterator_.enable_mask();
}
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) {
address_iterator_.set_mask(mask);
}
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) {
address_iterator_.get_mask(mask);
}
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, Index pointer_offset) {
load_with_byte_offset(
frag, pointer_offset * sizeof_bits<Element>::value / 8);
}
CUTLASS_DEVICE
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) {
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
int idx = v +
kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
address_iterator_.set_iteration_index(idx);
char const* byte_ptr =
reinterpret_cast<char const*>(address_iterator_.get()) +
byte_offset;
AccessType const* access_ptr =
reinterpret_cast<AccessType const*>(byte_ptr);
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[idx], access_ptr, address_iterator_.valid());
++address_iterator_;
}
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) {
load_with_byte_offset(frag, 0);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
store_with_byte_offset(
frag, pointer_offset * sizeof_bits<Element>::value / 8);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
address_iterator_.set_iteration_index(0);
AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
int idx = v +
kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
char* byte_ptr =
reinterpret_cast<char*>(address_iterator_.get()) + byte_offset;
AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr);
if (address_iterator_.valid()) {
*access_ptr = frag_ptr[idx];
}
++address_iterator_;
}
}
}
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) {
store_with_byte_offset(frag, 0);
}
};
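Usage Sketch
The sketch below illustrates how this specialization is typically driven: Params is built on the host from the tensor's layout, each thread constructs the iterator inside the kernel, and tiles are walked with operator++, with the residual (partial-tile) predicates enabled only for the final tile. This is a minimal sketch, not code from the PyTorch kernels: the tile shape, element type, thread map, access size, include paths, the cutlass::transform::threadblock namespace, and the helper consume_tiles are all assumptions chosen for illustration.
// Illustrative sketch only: Shape, ThreadMap, and AccessSize are assumed values,
// not the ones used by the mem_eff_attention kernels.
#include <cutlass/coord.h>
#include <cutlass/numeric_types.h>
#include <cutlass/layout/pitch_linear.h>
#include <cutlass/transform/pitch_linear_thread_map.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h>

using Shape     = cutlass::layout::PitchLinearShape<64, 8>;   // threadblock tile (assumed)
using Element   = cutlass::half_t;                            // element type (assumed)
using Layout    = cutlass::layout::AffineRankN<2>;            // the layout this specialization handles
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<Shape, /*Threads=*/32>;

using Iterator = cutlass::transform::threadblock::PredicatedTileIteratorResidualLast<
    Shape, Element, Layout, /*AdvanceRank=*/0, ThreadMap, /*AccessSize=*/1, /*Gather=*/false>;

// Params is host-constructible from the tensor's layout (strides) and is passed
// to the kernel together with the base pointer and the tensor extent.
__device__ void consume_tiles(
    typename Iterator::Params const& params,
    Element* pointer,
    typename Iterator::TensorCoord extent,
    int tile_count) {
  Iterator it(params, pointer, extent, threadIdx.x, cutlass::make_Coord(0, 0));
  typename Iterator::Fragment frag;

  for (int k = tile_count; k > 0; --k) {
    // "Residual last": steady-state tiles use the full-tile predicates; only the
    // final (possibly partial) tile switches to the residual predicates.
    it.set_residual_tile(k == 1);
    it.load(frag);
    ++it;  // advance along the dimension selected by AdvanceRank
    // ... hand `frag` to the next pipeline stage (e.g. a shared-memory store) ...
  }
}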