mkldnn_gemm Function Template — pytorch Architecture
Architecture documentation for the mkldnn_gemm function template in Matmul.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/mkldnn/Matmul.cpp lines 213–267
template<typename scalar_t>
inline typename std::enable_if_t<
    std::is_same_v<scalar_t, c10::BFloat16>,
    bool>
mkldnn_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    float alpha,
    const scalar_t *a_data, int64_t lda,
    const scalar_t *b_data, int64_t ldb,
    float beta,
    float* c_data, int64_t ldc) {
  // Attempts a BLAS-style GEMM with bf16 inputs (A, B) and an fp32 output (C)
  // through oneDNN (via ideep).
  //
  // Returns false without touching C when use_mkldnn_bf16_matmul() reports the
  // bf16 path as unusable, so the caller can fall back to another GEMM
  // implementation. Returns true once C has been written.
  //
  // transa/transb: any value other than TransposeType::NoTranspose swaps the
  //                strides of the corresponding operand view.
  // m, n, k:       GEMM sizes; A is viewed as a {k, m} tensor, B as {n, k} and
  //                C as {n, m} (see the identity note below).
  // alpha, beta:   forwarded to ideep::matmul_forward::compute; a non-zero
  //                beta also enables a fused sum post-op so the result is
  //                added into the existing contents of C.
  // lda/ldb/ldc:   leading dimensions, used as the outer stride of each view.
  //
  // NOTE(review): an earlier comment here referred to a size heuristic
  // (m * n * k <= 16 * 16 * 16). No such check is visible in this function;
  // presumably it lives in (or was removed from) use_mkldnn_bf16_matmul().
  // Confirm.
  if (!use_mkldnn_bf16_matmul()) {
    return false;
  }

  ideep::attr_t op_attr;
  // Use a oneDNN sum post-op to perform the add into C.
  if (beta != 0.0f) {
    op_attr = ideep::attr_t::fuse_sum();
  }

  // NOTE: View as c-contiguous to avoid extra reordering in mkldnn.
  // Use identity: C = AB <=> C^T = B^T A^T. The operands are viewed as their
  // transposes and the product is computed as B * A (see the compute call).
  ideep::tensor::dims a_strides{{lda, 1}}, b_strides{{ldb, 1}}, c_strides{{ldc, 1}};
  if (transa != TransposeType::NoTranspose) {
    std::swap(a_strides[0], a_strides[1]);
  }
  if (transb != TransposeType::NoTranspose) {
    std::swap(b_strides[0], b_strides[1]);
  }

  // Inputs are bf16; the output buffer (c_data) is fp32. Keep the two dtypes
  // distinct so the output is never described with the input dtype.
  const auto in_dtype = ideep::tensor::data_type::bf16;
  const auto out_dtype = ideep::tensor::data_type::f32;
  ideep::tensor a = make_ideep_tensor<scalar_t>({k, m}, in_dtype, a_strides, const_cast<scalar_t*>(a_data));
  ideep::tensor b = make_ideep_tensor<scalar_t>({n, k}, in_dtype, b_strides, const_cast<scalar_t*>(b_data));
  ideep::tensor c = make_ideep_tensor<float>({n, m}, out_dtype, c_strides, c_data);
  ideep::matmul_forward::compute(
      b, a, c, alpha, beta,
      ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), op_attr);

  if (c.get_data_handle() != c_data) {
    // ideep queries oneDNN for its expected output format. If the given
    // layout is not the expected one, ideep re-initializes `c` with its own
    // buffer, so the result must be copied back into the caller's buffer.
    //
    // The destination wraps the float* c_data and must therefore be described
    // as f32, matching `c` above. Describing it as bf16 (the input dtype)
    // would make the reorder convert to bf16 and write 2-byte elements into
    // the 4-byte float storage, corrupting C.
    ideep::tensor real_output = make_ideep_tensor<float>({n, m}, out_dtype, c_strides, c_data);
    c.reorder_to(real_output);
  }
  return true;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free