
requires_same_num_heads Class — pytorch Architecture

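Architecture documentation for requires_same_num_heads in sdp_utils_cpp.h from the pytorch codebase. Although it is indexed here as a class, requires_same_num_heads is a bool template parameter (default true) of the inline function check_grouped_query_attention. When it is true, key and value must have the same number of heads, and that number must divide the query head count. When it is false, the key and value head counts may differ, but each must divide the query head count.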


Source Code

aten/src/ATen/native/transformers/sdp_utils_cpp.h lines 337–375

template <bool requires_same_num_heads=true>
inline bool check_grouped_query_attention(sdp_params const& params, bool debug) {
  const auto q_num_heads = params.query.sym_size(-3);
  const auto k_num_heads = params.key.sym_size(-3);
  const auto v_num_heads = params.value.sym_size(-3);
  const bool same_kv_heads = k_num_heads == v_num_heads;

  if (requires_same_num_heads && !same_kv_heads){
    if (debug) {
      TORCH_WARN(
          "Both fused kernels require key and value to have the same num_heads and batch_size but got: ",
          "Key sizes: ",
          params.key.sizes(),
          ", Value sizes: ",
          params.value.sizes(),
          ", Query sizes: ",
          params.query.sizes(),
          " instead.");
    }
    return false;
  }
  // Check if grouped query attention is supported and validate the number of
  // heads
  if (q_num_heads % k_num_heads != 0 || (!requires_same_num_heads && (q_num_heads % v_num_heads != 0))) {
    if (debug) {
      TORCH_WARN(
          "The number of heads in key/value must divide number of heads in query.",
          "Got input Key sizes(): ",
          params.key.sym_size(-3),
          ", Value sizes(): ",
          params.value.sym_size(-3),
          ", Query sizes(): ",
          params.query.sym_size(-3),
          " instead.");
    }
    return false;
  }
  return true;
}
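The head-count logic above can be sketched outside of ATen as a standalone function. This is a minimal sketch, not PyTorch code: gqa_heads_ok and kv_head_for_query_head are hypothetical names, plain std::int64_t values stand in for the c10::SymInt results of sym_size(-3), and the TORCH_WARN debug path is omitted. The second helper assumes a repeat_interleave-style grouping, in which each key/value head serves a contiguous block of q_heads / kv_heads query heads. That layout is why the query head count must be divisible by the key/value head count.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch of the head-count checks in check_grouped_query_attention.
// Plain std::int64_t head counts replace the c10::SymInt values returned by
// sym_size(-3); the TORCH_WARN debug output is omitted.
template <bool requires_same_num_heads = true>
bool gqa_heads_ok(std::int64_t q_heads, std::int64_t k_heads, std::int64_t v_heads) {
  // With requires_same_num_heads == true, key and value must agree exactly.
  if (requires_same_num_heads && k_heads != v_heads) {
    return false;
  }
  // Key heads must always divide query heads; value heads are checked
  // separately only when key and value are allowed to differ.
  // Like the original, this assumes non-zero key/value head counts.
  if (q_heads % k_heads != 0 ||
      (!requires_same_num_heads && q_heads % v_heads != 0)) {
    return false;
  }
  return true;
}

// Hypothetical helper: under a repeat_interleave-style layout, query head i
// reads from key/value head i / (q_heads / kv_heads).
std::int64_t kv_head_for_query_head(std::int64_t query_head,
                                    std::int64_t q_heads,
                                    std::int64_t kv_heads) {
  const std::int64_t group_size = q_heads / kv_heads;  // query heads per KV head
  return query_head / group_size;
}
```

With the default template argument, 8 query heads and 2 key/value heads pass, while 8 query heads with 3 key/value heads are rejected because 8 is not divisible by 3. Instantiating with false additionally accepts a key/value pair such as 2 and 4 heads, since both divide 8.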
