Home / Class/ supports_gqa Class — pytorch Architecture

supports_gqa Class — pytorch Architecture

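Architecture documentation for `supports_gqa`, a boolean template parameter of the `check_batch_size_and_num_heads_dense` function template in sdp_utils_cpp.h from the PyTorch codebase (listed here as a "class" entity).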

Entity Profile

Source Code

aten/src/ATen/native/transformers/sdp_utils_cpp.h lines 377–432

// Checks that dense query/key/value inputs agree on batch size and, outside
// the grouped-query-attention path, on the number of heads.
//
// Template parameters:
//   supports_gqa            - when true and params.enable_gqa is set, the
//                             head-count decision is delegated to
//                             check_grouped_query_attention.
//   requires_same_num_heads - forwarded unchanged to
//                             check_grouped_query_attention.
//
// Arguments:
//   params - the query, key and value tensors plus the enable_gqa flag.
//   debug  - when true, a TORCH_WARN describing the mismatch is emitted
//            before returning false.
//
// Returns true when every applicable check passes, false otherwise.
template <bool supports_gqa, bool requires_same_num_heads=true>
inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) {
  // Precondition: check_tensor_shapes has already run, so all three inputs
  // are 4 dimensional and the dimension lookups below will not error.

  // Dimension 0 holds the batch size.
  const auto batch_q = params.query.sym_size(0);
  const auto batch_k = params.key.sym_size(0);
  const auto batch_v = params.value.sym_size(0);
  const bool batch_sizes_match = batch_q == batch_k && batch_q == batch_v;

  // Dimension -3 holds the number of heads.
  const auto heads_q = params.query.sym_size(-3);
  const auto heads_k = params.key.sym_size(-3);
  const auto heads_v = params.value.sym_size(-3);
  const bool num_heads_match = heads_q == heads_k && heads_q == heads_v;

  // Batch size must agree regardless of whether GQA is in use.
  if (!batch_sizes_match) {
    if (debug) {
      TORCH_WARN(
          "For dense inputs, both fused kernels require query, key and value to have the same batch_size. ",
          "Query.sizes(): ",
          params.query.sizes(),
          ", Key.sizes(): ",
          params.key.sizes(),
          ", Value.sizes(): ",
          params.value.sizes(),
          " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
    }
    return false;
  }

  // With GQA requested and supported by the caller's kernel, the head-count
  // rules are decided entirely by the dedicated GQA check.
  if (params.enable_gqa && supports_gqa) {
    return check_grouped_query_attention<requires_same_num_heads>(params, debug);
  }

  // Non-GQA case: matching head counts mean every check has passed.
  if (num_heads_match) {
    return true;
  }

  if (debug) {
    TORCH_WARN(
        "For dense input, both fused kernels require query, key and value to have the same num_heads. ",
        "Query.sizes(): ",
        params.query.sizes(),
        ", Key sizes(): ",
        params.key.sizes(),
        ", Value sizes(): ",
        params.value.sizes(),
        " instead. To broadcast dense inputs, try using unsqueeze and expand_to before passing them into the kernel.");
  }
  return false;
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free