requires_same_num_heads Class — pytorch Architecture
Architecture documentation for requires_same_num_heads, the boolean template parameter of the check_grouped_query_attention function in sdp_utils_cpp.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/sdp_utils_cpp.h lines 337–375
/// Validates the head counts of query/key/value for grouped query attention.
///
/// The head count of each tensor is read as the size of dimension -3.
///
/// @tparam requires_same_num_heads When true (the default), key and value must
///         have the same head count; when false, key and value may differ but
///         each must independently divide the query head count.
/// @param params Inputs whose `query`, `key` and `value` tensors are inspected.
/// @param debug  When true, a TORCH_WARN describing the failed check is
///               emitted before returning false.
/// @return true if the head counts are compatible, false otherwise.
template <bool requires_same_num_heads=true>
inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
  const auto q_num_heads = params.query.sym_size(-3);
  const auto k_num_heads = params.key.sym_size(-3);
  const auto v_num_heads = params.value.sym_size(-3);
  const bool same_kv_heads = k_num_heads == v_num_heads;

  if (requires_same_num_heads && !same_kv_heads) {
    if (debug) {
      TORCH_WARN(
          "Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
          "Key sizes: ",
          params.key.sizes(),
          ", Value sizes: ",
          params.value.sizes(),
          ", Query sizes: ",
          params.query.sizes(),
          " instead.");
    }
    return false;
  }

  // The key (and, when it is allowed to differ, the value) head count must
  // evenly divide the query head count. A zero head count is rejected up
  // front so the modulo below is never evaluated with a zero divisor. When
  // requires_same_num_heads is true, value equals key at this point, so the
  // value check would be redundant and is skipped.
  if (k_num_heads == 0 || q_num_heads % k_num_heads != 0 ||
      (!requires_same_num_heads &&
       (v_num_heads == 0 || q_num_heads % v_num_heads != 0))) {
    if (debug) {
      TORCH_WARN(
          "The number of heads in key/value must divide number of heads in query. ",
          "Got input Key sizes(): ",
          k_num_heads,
          ", Value sizes(): ",
          v_num_heads,
          ", Query sizes(): ",
          q_num_heads,
          " instead.");
    }
    return false;
  }
  return true;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free