PredicatedTileAccessIteratorResidualLast (AffineRankN<2> specialization) — PyTorch Architecture
Architecture documentation for the PredicatedTileAccessIteratorResidualLast class template, as specialized for layout::AffineRankN<2>, in predicated_tile_access_iterator_residual_last.h from the PyTorch codebase.
Entity Profile
PredicatedTileAccessIteratorResidualLast is a CUTLASS-style tile access iterator used by PyTorch's CUDA memory-efficient attention kernels. The partial specialization excerpted below handles rank-2 affine layouts (layout::AffineRankN<2>) with gather/scatter disabled: a host-constructible Params object precomputes all pointer increments in bytes, and an underlying predicate object masks out-of-bounds accesses.
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h, lines 899–1068
template <
    typename Shape_,
    typename Element_,
    int AdvanceRank,
    typename ThreadMap_,
    typename AccessType_>
class PredicatedTileAccessIteratorResidualLast<
    Shape_,
    Element_,
    layout::AffineRankN<2>,
    AdvanceRank,
    ThreadMap_,
    AccessType_,
    false> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may advance along the "
      "contiguous(rank=0) or strided(rank=1) dimension.");
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::AffineRankN<2>;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;
  using AccessType = AccessType_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element*;
  using NonConstPointer = typename platform::remove_const<Element>::type*;

  using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates<
      Shape,
      Element,
      layout::PitchLinear,
      AdvanceRank,
      ThreadMap,
      AccessType>;

  static int const kAccessesPerVector =
      ThreadMap::kElementsPerAccess / AccessType::kElements;

  static_assert(
      !(ThreadMap::kElementsPerAccess % AccessType::kElements),
      "Vectors implied by the thread map must be divisible by the access type.");
  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingPredicates::Mask;
  /// Parameters object is precomputed state and is host-constructible
  class Params {
   public:
    friend PredicatedTileAccessIteratorResidualLast;

   private:
    /// stride of pitch-linear layout (units of Element)
    Coord<Layout::kStrideRank, Layout::LongIndex> stride_;
    /// amount (in bytes) to increment pointer to move to next access along
    /// contiguous dimension
    LongIndex inc_contiguous_;
    /// amount (in bytes) to increment pointer from first access of current
    /// contiguous dimension to first access of next one
    LongIndex inc_strided_;
    /// amount (in bytes) to increment pointer from last access of current
    /// contiguous dimension to first access of next one
    LongIndex inc_next_strided_;
    /// amount (in bytes) to increment pointer from last access to first access
    /// of next tile
    LongIndex inc_next_;
    /// amount (in bytes) to increment pointer from first access of current
    /// tile to first access of next tile
    LongIndex inc_advance_;

   public:
    // Default ctor
    CUTLASS_HOST_DEVICE
    Params()
        : stride_(0),
          inc_contiguous_(0),
          inc_strided_(0),
          inc_next_(0),
          inc_advance_(0) {}

    /// Construct the Params object given a pitch-linear tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const& layout)
        : stride_({layout.stride(0), layout.stride(1)}) {
      inc_contiguous_ =
          (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) *
          sizeof_bits<Element>::value / 8;

      inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) *
          sizeof_bits<Element>::value / 8;

      inc_next_strided_ = inc_strided_ -
          LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_;

      if (kAdvanceRank) {
        // advance along strided dimension
        inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) *
            sizeof_bits<Element>::value / 8;
      } else {
        // advance along contiguous dimension
        inc_advance_ =
            Shape::kContiguous * stride_[0] * sizeof_bits<Element>::value / 8;
      }

      inc_next_ = inc_advance_ -
          LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ -
          LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_;
    }
  };
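  // Note: every increment above is expressed in bytes
  // (sizeof_bits<Element>::value / 8 converts from element units), which is
  // why the iterator below walks a raw char* (BytePointer) and can advance
  // with plain additions instead of per-access index arithmetic.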
 private:
  /// Internal pointer type permits fast address arithmetic
  using BytePointer = char*;

  //
  // Data members
  //

  /// Parameters object with precomputed internal state
  Params params_;

  /// Internal pointer to first access of tile
  BytePointer pointer_;

  UnderlyingPredicates the_predicates;
  Mask residual_tile_mask;

 private:
  /// Computes predicates based on internally tracked per-thread offset.
  CUTLASS_DEVICE
  void compute_predicates_(
      /// Extent of the matrix window
      TensorCoord extent,
      /// optionally, simplify predicate calculation during 'steady state'
      /// phase
      bool is_steady_state = false) {
    the_predicates.compute_predicates_(extent, is_steady_state);
  }
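  // A reasonable reading of the flag (consistent with the underlying
  // predicate helper): once the residual tile has been consumed, the iterator
  // is in its 'steady state' and only the dimension orthogonal to
  // kAdvanceRank can still run out of bounds, so the predicate computation
  // can skip the comparison along the advance dimension.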
 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileAccessIteratorResidualLast(
      ///< Precomputed parameters object
      Params const& params,
      ///< Pointer to start of tensor
      Pointer pointer,
      ///< Extent of tensor
      TensorCoord extent,
      ///< ID of each participating thread
      int thread_id,
      ///< Initial offset of threadblock
      TensorCoord const& threadblock_offset,
      ///< gather/scatter indices; note: gather/scatter is not supported by
      ///< this specialization
      int const* indices = nullptr)
      : params_(params),
        pointer_(reinterpret_cast<BytePointer>(
            const_cast<NonConstPointer>(pointer))),
        the_predicates(extent) {
    the_predicates.set_predicates(thread_id, threadblock_offset);

    // update internal pointers
    Layout layout(params_.stride_);
    add_pointer_offset(layout(the_predicates.thread_offset_));
  }
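Worked Example
Everything the device-side iterator does at runtime reduces to adding one of the five precomputed byte increments to pointer_. The standalone sketch below is not part of the PyTorch source; it mirrors the arithmetic of Params(Layout const&) above for the kAdvanceRank == 1 case, using an entirely hypothetical configuration (strides, thread-map deltas, iteration counts, and element size chosen only to make the numbers concrete):

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical configuration (all values are illustrative assumptions).
  int64_t const stride[2] = {1, 64};   // AffineRank2 strides, in elements
  int64_t const element_bytes = 2;     // e.g. a 16-bit element type
  int64_t const delta_contiguous = 8;  // ThreadMap::Delta::kContiguous
  int64_t const delta_strided = 4;     // ThreadMap::Delta::kStrided
  int64_t const iters_contiguous = 4;  // ThreadMap::Iterations::kContiguous
  int64_t const iters_strided = 2;     // ThreadMap::Iterations::kStrided
  int64_t const shape_strided = 8;     // Shape::kStrided (kAdvanceRank == 1)

  // Same formulas as the Params constructor, with sizeof_bits<Element>/8
  // folded into element_bytes.
  int64_t const inc_contiguous = stride[0] * delta_contiguous * element_bytes;
  int64_t const inc_strided = stride[1] * delta_strided * element_bytes;
  int64_t const inc_next_strided =
      inc_strided - (iters_contiguous - 1) * inc_contiguous;
  int64_t const inc_advance = shape_strided * stride[1] * element_bytes;
  int64_t const inc_next = inc_advance -
      (iters_contiguous - 1) * inc_contiguous -
      (iters_strided - 1) * inc_strided;

  // Prints: inc_contiguous=16 inc_strided=512 inc_next_strided=464
  //         inc_advance=1024 inc_next=464
  std::printf(
      "inc_contiguous=%lld inc_strided=%lld inc_next_strided=%lld\n"
      "inc_advance=%lld inc_next=%lld\n",
      (long long)inc_contiguous, (long long)inc_strided,
      (long long)inc_next_strided, (long long)inc_advance,
      (long long)inc_next);
  return 0;
}

Precomputing these increments on the host is the design point of the Params object: the kernel's inner loop never multiplies indices by strides; it only adds a constant byte offset per access, per row, or per tile.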