check_tensor_dtype Function — pytorch Architecture
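Architecture documentation for check_tensor_dtype, an inline function template in sdp_utils_cpp.h from the PyTorch codebase, and one of the scaled dot product attention (SDP) constraint checks. It returns true only when the query, key and value tensors in sdp_params all share a single dtype and that dtype appears in allowed_dtypes. Otherwise it returns false and, if debug is set, emits a TORCH_WARN that lists the allowed dtypes alongside the actual query, key and value dtypes.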
Source Code
aten/src/ATen/native/transformers/sdp_utils_cpp.h lines 80–106
```cpp
// Returns true if query, key and value share one dtype that is also listed in
// allowed_dtypes. dtype_vector can be any iterable container of dtypes.
template <typename dtype_vector>
inline bool check_tensor_dtype(
    sdp_params const& params,
    dtype_vector allowed_dtypes,
    bool debug) {
  auto query_dtype = params.query.dtype();
  // All three dtypes must match, and the shared dtype must be allowed.
  if (!(query_dtype == params.key.dtype() &&
        query_dtype == params.value.dtype() &&
        (std::find(allowed_dtypes.begin(), allowed_dtypes.end(), query_dtype) !=
         allowed_dtypes.end()))) {
    // Explain the rejection only when the caller asked for diagnostics.
    if (debug) {
      TORCH_WARN(
          "Expected query, key and value to all be of dtype: {",
          c10::Join(", ", allowed_dtypes),
          "}. Got ",
          "Query dtype: ",
          params.query.dtype(),
          ", Key dtype: ",
          params.key.dtype(),
          ", and Value dtype: ",
          params.value.dtype(),
          " instead.");
    }
    return false;
  }
  return true;
}
```
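To make the control flow easier to experiment with outside of ATen, here is a minimal, self-contained sketch of the same check. It is an illustration under stated assumptions, not PyTorch code: ScalarType, MockSdpParams, dtype_name, kHalfPrecisionDtypes and check_tensor_dtype_sketch are all hypothetical stand-ins, the params struct carries only dtypes instead of full tensors, and the warning goes to std::cerr rather than through TORCH_WARN. The two conditions (all three dtypes equal, and the shared dtype present in the allowed list) are kept exactly as in the original.

```cpp
#include <algorithm>
#include <array>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for at::ScalarType (only a few members shown).
enum class ScalarType { Half, BFloat16, Float };

// Hypothetical helper that turns a dtype into a printable name.
std::string dtype_name(ScalarType t) {
  switch (t) {
    case ScalarType::Half: return "Half";
    case ScalarType::BFloat16: return "BFloat16";
    case ScalarType::Float: return "Float";
  }
  return "Unknown";
}

// Hypothetical stand-in for sdp_params: only the dtypes of q, k and v.
struct MockSdpParams {
  ScalarType query;
  ScalarType key;
  ScalarType value;
};

// Hypothetical allowed-dtype list; any container with begin()/end() works.
constexpr std::array<ScalarType, 2> kHalfPrecisionDtypes{
    ScalarType::Half, ScalarType::BFloat16};

// Mirrors check_tensor_dtype: q, k and v must share one dtype, and that
// dtype must appear in allowed_dtypes.
template <typename dtype_vector>
bool check_tensor_dtype_sketch(
    const MockSdpParams& params,
    dtype_vector allowed_dtypes,
    bool debug) {
  const ScalarType query_dtype = params.query;
  const bool same_dtype =
      query_dtype == params.key && query_dtype == params.value;
  const bool allowed =
      std::find(allowed_dtypes.begin(), allowed_dtypes.end(), query_dtype) !=
      allowed_dtypes.end();
  if (!(same_dtype && allowed)) {
    if (debug) {
      // Join allowed dtype names with ", " (similar in spirit to c10::Join).
      std::string joined;
      for (auto it = allowed_dtypes.begin(); it != allowed_dtypes.end(); ++it) {
        if (it != allowed_dtypes.begin()) joined += ", ";
        joined += dtype_name(*it);
      }
      std::cerr << "Expected query, key and value to all be of dtype: {"
                << joined << "}. Got Query dtype: " << dtype_name(params.query)
                << ", Key dtype: " << dtype_name(params.key)
                << ", and Value dtype: " << dtype_name(params.value)
                << " instead.\n";
    }
    return false;
  }
  return true;
}
```

With this sketch, three Half inputs checked against kHalfPrecisionDtypes pass, three Float inputs fail the membership test, and a Half/Float/Half mix fails the equality test (printing the warning when debug is true). Because dtype_vector is a template parameter, the same function accepts a constexpr std::array or a std::vector of dtypes without any overloads.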