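check_batch_size_and_num_heads_dense (template parameter supports_gqa) — PyTorch Architecture
Architecture documentation for the check_batch_size_and_num_heads_dense function template, whose template parameter supports_gqa controls grouped-query-attention handling, in sdp_utils_cpp.h from the PyTorch codebase.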
Entity Profile
Source Code
aten/src/ATen/native/transformers/sdp_utils_cpp.h lines 377–432
// Validates that dense query, key and value inputs agree on batch size and,
// unless the grouped-query-attention (GQA) path is taken, on number of heads.
//
// Template parameters:
//   supports_gqa            - compile-time flag; when true AND
//                             params.enable_gqa is set, the head-count check
//                             is delegated to check_grouped_query_attention.
//   requires_same_num_heads - forwarded unchanged as the template argument of
//                             check_grouped_query_attention (semantics are
//                             defined there).
//
// params: the SDPA inputs. Dimension 0 is read as the batch size and
//         dimension -3 as num_heads.
// debug:  when true, a TORCH_WARN describing the mismatch is emitted before
//         returning false.
// returns true if every applicable check passes, false otherwise.
template <bool supports_gqa, bool requires_same_num_heads=true>
inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
  // This is expected to be called after check_tensor_shapes ensuring that the
  // size() calls won't error since the inputs are all 4 dimensional
  auto q_batch_size = params.query.sym_size(0);
  auto k_batch_size = params.key.sym_size(0);
  auto v_batch_size = params.value.sym_size(0);
  bool same_batch_size =
      q_batch_size == k_batch_size && q_batch_size == v_batch_size;
  if (!same_batch_size) {
    if (debug) {
      TORCH_WARN(
          "For dense inputs, both fused kernels require query, key and value to have the same batch_size. ",
          "Query.sizes(): ",
          params.query.sizes(),
          ", Key.sizes(): ",
          params.key.sizes(),
          ", Value.sizes(): ",
          params.value.sizes(),
          " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
    }
    return false;
  }

  // supports_gqa is known at compile time; the GQA helper is only
  // instantiated for callers that can actually use it.
  if constexpr (supports_gqa) {
    if (params.enable_gqa) {
      return check_grouped_query_attention<requires_same_num_heads>(params, debug);
    }
  }

  // Same num heads condition for the non-GQA case. The comparison is done
  // here, not up front, because only this path consumes it.
  // NOTE(review): SymInt equality may record a guard under symbolic shapes;
  // deferring it presumably avoids a needless guard on the early-return
  // paths — confirm against c10::SymInt.
  auto q_num_heads = params.query.sym_size(-3);
  auto k_num_heads = params.key.sym_size(-3);
  auto v_num_heads = params.value.sym_size(-3);
  bool same_num_heads =
      q_num_heads == k_num_heads && q_num_heads == v_num_heads;
  if (!same_num_heads) {
    if (debug) {
      TORCH_WARN(
          "For dense input, both fused kernels require query, key and value to have the same num_heads. ",
          "Query.sizes(): ",
          params.query.sizes(),
          ", Key sizes(): ",
          params.key.sizes(),
          ", Value sizes(): ",
          params.value.sizes(),
          " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
    }
    return false;
  }
  // If all checks pass, return true
  return true;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free