zfillGradKV function (template parameter skipBoundsChecks) — pytorch Architecture
Architecture documentation for the zfillGradKV device function — templated on the bool non-type parameter skipBoundsChecks (not a class) — in kernel_backward.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h lines 1403–1435
// Zero-fills the gradK / gradV rows for a block of keys that attend to no
// query at all (this can happen under causal masking), so their gradient
// output is well-defined instead of uninitialized.
//
// Template:
//   skipBoundsChecks - compile-time flag; when true the `key < p.num_keys`
//                      guard is elided (caller guarantees a full block).
// Parameters:
//   p         - kernel parameters (gradient base pointers, strides, dims)
//   key_start - index of the first key handled by this block
//   warp_id   - warp index within the thread block
//   lane_id   - lane index within the warp (0-31)
template <bool skipBoundsChecks>
static CUTLASS_DEVICE void zfillGradKV(
    Params const& p,
    int32_t key_start,
    uint8_t warp_id,
    uint8_t lane_id) {
  // kThreadsPerKey lanes cooperate on zeroing one key's row; the block
  // therefore clears kParallelKeys keys per outer-loop iteration.
  constexpr int kThreadsPerKey = 8;
  constexpr int kParallelKeys = kNumThreads / kThreadsPerKey;
  static_assert(
      kBlockSizeJ % kParallelKeys == 0,
      "kBlockSizeJ must be a multiple of kParallelKeys so the unrolled "
      "loop covers the block exactly");
  // This function is not really optimized, but should rarely be used
  // It's only used when some keys are "useless" and don't attend to
  // any query, due to causal masking
  int thread_id = 32 * warp_id + lane_id;
  // Offset of this lane within its 8-lane group; since 32 % kThreadsPerKey
  // == 0, lane_id % kThreadsPerKey equals thread_id % kThreadsPerKey.
  int k_shift = lane_id % kThreadsPerKey;
  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) {
    int key = key_start + j + (thread_id / kThreadsPerKey);
    // Guard against partial blocks unless the caller opted out at
    // compile time via skipBoundsChecks.
    if (!skipBoundsChecks && key >= p.num_keys) {
      continue;
    }
    auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM();
    auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM();
    // Strided cooperative zero-fill: each of the 8 lanes in the group
    // writes every 8th element of the row.
    for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) {
      gv_ptr[k] = scalar_t(0);
    }
    for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) {
      gk_ptr[k] = scalar_t(0);
    }
  }
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free