kElementsPerAccess (computeDelta) — pytorch Architecture
Architecture documentation for the kElementsPerAccess template parameter of the computeDelta device function in kernel_backward.h from the pytorch codebase. (kElementsPerAccess is a non-type template parameter, not a class: it sets how many scalars each vectorized global load fetches.)
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h lines 2498–2598
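computeDelta produces the per-row Delta vector consumed by the attention backward pass: for each query row q it accumulates delta[q] = sum_k grad_output[q][k] * output[q][k], a dot product between the incoming gradient and the forward output along the head dimension. The listing below walks the columns with a two-stage prefetch pipeline, then reduces partial sums across the threads that share a row.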
template <int kElementsPerAccess>
static CUTLASS_DEVICE void computeDelta(
Params const& p,
int32_t query_start,
uint8_t warp_id,
uint8_t lane_id) {
// Each thread computes one value for Delta
// Depending on warp configuration, we might have multiple
// threads of the same warp working on the same row
using AccessType = cutlass::Array<scalar_t, kElementsPerAccess>;
static_assert(kNumThreads >= kBlockSizeI, "");
static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI;
int16_t thread_id = 32 * warp_id + lane_id;
int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine);
int16_t laneRow = thread_id / kNumThreadsPerLine;
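// Thread mapping: each of the kBlockSizeI rows is covered by
// kNumThreadsPerLine threads; thread t owns row t / kNumThreadsPerLine
// and starts at column kElementsPerAccess * (lane_id % kNumThreadsPerLine).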
bool rowPred = (query_start + laneRow) < p.num_queries;
bool pred = rowPred;
// on windows, previous syntax __restrict__ AccessType*
// resulted in error: "restrict" is not allowed
const AccessType* __restrict__ grad_output_ptr =
reinterpret_cast<const AccessType*>(
p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM +
laneFirstCol);
const AccessType* __restrict__ output_ptr =
reinterpret_cast<const AccessType*>(
p.output_ptr + (query_start + laneRow) * p.o_strideM() +
laneFirstCol);
static constexpr int64_t kMaxIters =
kMaxK / (kElementsPerAccess * kNumThreadsPerLine);
constexpr int kPipelineStages = 2;
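// kMaxIters bounds the number of column steps; each step covers
// kElementsPerAccess * kNumThreadsPerLine scalars of the row.
// kPipelineStages = 2 gives double buffering: the next fragment is
// fetched while the current one is accumulated.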
accum_t delta_value = accum_t(0);
using GlobalLoad =
cutlass::arch::global_load<AccessType, sizeof(AccessType)>;
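// global_load leaves its destination unchanged when the predicate is
// false, so each fragment is clear()-ed before loading to make
// out-of-bounds lanes contribute zero to the dot product.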
AccessType frag_grad_output[kPipelineStages];
AccessType frag_output[kPipelineStages];
auto loadAndIncrement = [&](int ld_pos, bool is_valid) {
frag_grad_output[ld_pos].clear();
frag_output[ld_pos].clear();
GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid);
GlobalLoad(frag_output[ld_pos], output_ptr, is_valid);
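// Advance one full line width: kNumThreadsPerLine AccessType vectors,
// i.e. kElementsPerAccess * kNumThreadsPerLine scalars along the row.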
grad_output_ptr += kNumThreadsPerLine;
output_ptr += kNumThreadsPerLine;
};
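// Pipeline prologue: issue the first kPipelineStages - 1 loads before
// any accumulation takes place.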
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < kPipelineStages - 1; ++iter) {
int ld_pos = iter % kPipelineStages;
pred = pred &&
(laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) <
p.head_dim_value;
loadAndIncrement(ld_pos, pred);
}
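// Steady state: iteration `iter` prefetches the fragments for stage
// iter + kPipelineStages - 1 while accumulating the fragments that were
// loaded for stage iter.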
auto columnIteration = [&](int iter) {
// Load for next iter
int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages;
pred = pred &&
(laneFirstCol +
(iter + kPipelineStages - 1) * kElementsPerAccess *
kNumThreadsPerLine) < p.head_dim_value;
loadAndIncrement(ld_pos, pred);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < AccessType::kElements; ++i) {
delta_value += accum_t(frag_output[iter % kPipelineStages][i]) *
accum_t(frag_grad_output[iter % kPipelineStages][i]);
}
};
// If we have a small lower-bound for K, we can unroll the loop
if (kMaxK <= 256) {
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < kMaxIters; ++iter) {
columnIteration(iter);
}
} else {
int num_iters =
ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) *
(kElementsPerAccess * kNumThreadsPerLine);
for (int iter = 0; iter < num_iters; ++iter) {
columnIteration(iter);
}
}
// Reduce between workers
static_assert(
kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 ||
kNumThreadsPerLine == 4,
"");
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kNumThreadsPerLine; i *= 2) {
delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i);
}
// Store in gmem
if (rowPred) {
p.delta_ptr[query_start + laneRow] = delta_value;
}
}
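For orientation, the same quantity can be computed by an unoptimized kernel with one thread per query row. The following is a minimal sketch for illustration only (the kernel name, launch shape, float dtype, and contiguous row-major layout are assumptions, not taken from kernel_backward.h); the production code above adds vectorized predicated loads, a two-stage prefetch pipeline, and a warp-shuffle reduction on top of this contract:

#include <cuda_runtime.h>

// Hypothetical reference kernel:
//   delta[q] = sum_k grad_output[q][k] * output[q][k]
// One thread per query row; [num_queries, head_dim] row-major, stride = head_dim.
__global__ void computeDeltaReference(
    const float* grad_output,
    const float* output,
    float* delta,
    int num_queries,
    int head_dim) {
  int q = blockIdx.x * blockDim.x + threadIdx.x;
  if (q >= num_queries) {
    return;
  }
  float acc = 0.f;
  for (int k = 0; k < head_dim; ++k) {
    acc += grad_output[q * head_dim + k] * output[q * head_dim + k];
  }
  delta[q] = acc;
}

Launched with one thread per row, e.g. computeDeltaReference<<<(num_queries + 127) / 128, 128>>>(gO, O, delta, num_queries, head_dim), this makes the contract explicit: one scalar per query row, independent of how the optimized version tiles threads across rows and columns.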