prologueQkNextIteration Function Template — PyTorch Architecture
Architecture documentation for the prologueQkNextIteration function template (parameterized by the bool non-type template parameter kForceReloadK) in kernel_backward.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h lines 2375–2411
// Stages the K and Q tiles needed by the next iteration's QK matmul into
// shared memory (the MMA prologue), so the loads overlap with current work.
//
// kForceReloadK — when true, K is re-staged even if the MMA's shared memory
//                 already contains the entire matrix.
// shared_storage — kernel shared-memory union; provides the QK operand tiles.
// p              — kernel parameters (pointers, strides, problem sizes).
// query_start / key_start — row offsets of the next Q / K tiles.
// warp_id / lane_id       — position of this thread within the block.
template <bool kForceReloadK>
static CUTLASS_DEVICE void prologueQkNextIteration(
    SharedStorage& shared_storage,
    Params const& p,
    int32_t query_start,
    int32_t key_start,
    uint8_t warp_id,
    uint8_t lane_id) {
  // Nothing to prefetch once either start offset runs past the problem size.
  bool const out_of_range =
      query_start >= p.num_queries || key_start >= p.num_keys;
  if (out_of_range) {
    return;
  }

  // Reload K unless shared memory already holds the whole matrix and the
  // caller did not force a reload.
  static constexpr bool kReloadK =
      kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat;

  // Flat thread index within the block (32 lanes per warp).
  int const tid = static_cast<int>(warp_id) * 32 + lane_id;

  // Global-memory iterator over the K tile (operand A of the QK matmul;
  // it is staged into shared_storage.mm_qk_k() below).
  typename MatmulQK::Mma::IteratorA iterator_A(
      {int32_t(p.k_strideM)},
      const_cast<scalar_t*>(p.key_ptr + key_start * p.k_strideM),
      {p.num_keys - key_start, p.head_dim},
      tid,
      cutlass::MatrixCoord{0, 0});

  // Global-memory iterator over the Q tile (operand B; staged into
  // shared_storage.mm_qk_q()).
  typename MatmulQK::Mma::IteratorB iterator_B(
      {int32_t(p.q_strideM)},
      const_cast<scalar_t*>(p.query_ptr + query_start * p.q_strideM),
      {p.head_dim, p.num_queries - query_start},
      tid,
      cutlass::MatrixCoord{0, 0});

  // Kick off the staged loads; the second template argument (true) matches
  // the original call and always stages Q.
  MatmulQK::Mma::prologue<kReloadK, true>(
      shared_storage.mm_qk_k(),
      shared_storage.mm_qk_q(),
      iterator_A,
      iterator_B,
      tid,
      p.head_dim);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free