DefaultScaledGemmOp Class: PyTorch Architecture
Architecture documentation for the DefaultScaledGemmOp class, defined in TunableGemm.h in the PyTorch codebase.
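Entity Profile
DefaultScaledGemmOp<T> is a Callable<ScaledGemmParams<T>> in ATen's CUDA TunableOp code. Its Call method passes every field of the ScaledGemmParams<T> it receives to at::cuda::blas::scaled_gemm. These fields are the transpose flags, the problem sizes m, n and k, the A and B operands with their scale pointers, leading dimensions, dtypes, scale dtypes and scaling types, the bias, the output C with its scale pointer, leading dimension and dtype, and the use_fast_accum flag. It passes std::nullopt for alpha and always returns OK.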
Source Code
aten/src/ATen/cuda/tunable/TunableGemm.h lines 85–116
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
  public:
    TuningStatus Call(const ScaledGemmParams<T>* params) override {
      at::cuda::blas::scaled_gemm(
          params->transa,
          params->transb,
          params->m,
          params->n,
          params->k,
          params->a,
          params->a_scale_ptr,
          params->lda,
          params->a_dtype,
          params->a_scale_dtype,
          params->a_scaling_type,
          params->b,
          params->b_scale_ptr,
          params->ldb,
          params->b_dtype,
          params->b_scale_dtype,
          params->b_scaling_type,
          params->bias_ptr,
          params->bias_dtype,
          params->c,
          params->c_scale_ptr,
          params->ldc,
          params->c_dtype,
          params->use_fast_accum,
          std::nullopt /* alpha */);
      return OK;
    }
};
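The pattern above is an op object that implements Call by forwarding a parameter struct to a BLAS routine and then reporting OK. It can be illustrated with a small, self-contained CPU sketch. Everything prefixed with Toy, and the function toy_scaled_gemm, is hypothetical. Callable and TuningStatus are simplified re-declarations that mirror the names used in the excerpt, not the real ATen definitions. The scaled-GEMM semantics are illustrative assumptions, not the contract of at::cuda::blas::scaled_gemm or cuBLASLt. These assumptions are float inputs, one scalar scale per operand, column-major storage with no transposes, an optional bias of length m added to every column, and alpha defaulting to 1.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Simplified stand-ins for the TunableOp types named in the excerpt (assumption).
enum TuningStatus { OK = 0, FAIL = 1 };

template <typename ParamsT>
class Callable {
 public:
  virtual ~Callable() = default;
  virtual TuningStatus Call(const ParamsT* params) = 0;
};

// Hypothetical, heavily reduced parameter struct (column-major, no transposes).
struct ToyScaledGemmParams {
  std::int64_t m, n, k;
  const float* a; const float* a_scale_ptr; std::int64_t lda;
  const float* b; const float* b_scale_ptr; std::int64_t ldb;
  const float* bias_ptr;  // length m, or nullptr for no bias
  float* c; std::int64_t ldc;
};

// Hypothetical reference kernel: C = alpha * (a_scale * A) * (b_scale * B) + bias.
void toy_scaled_gemm(std::int64_t m, std::int64_t n, std::int64_t k,
                     const float* a, float a_scale, std::int64_t lda,
                     const float* b, float b_scale, std::int64_t ldb,
                     const float* bias, float* c, std::int64_t ldc,
                     std::optional<float> alpha) {
  const float scale = alpha.value_or(1.0f) * a_scale * b_scale;
  for (std::int64_t j = 0; j < n; ++j) {
    for (std::int64_t i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (std::int64_t p = 0; p < k; ++p) {
        acc += a[i + p * lda] * b[p + j * ldb];  // column-major indexing
      }
      c[i + j * ldc] = scale * acc + (bias ? bias[i] : 0.0f);
    }
  }
}

// Mirrors the shape of DefaultScaledGemmOp: forward every field, pass no alpha,
// and report OK.
class ToyDefaultScaledGemmOp : public Callable<ToyScaledGemmParams> {
 public:
  TuningStatus Call(const ToyScaledGemmParams* p) override {
    assert(p->a_scale_ptr != nullptr && p->b_scale_ptr != nullptr);
    toy_scaled_gemm(p->m, p->n, p->k,
                    p->a, *p->a_scale_ptr, p->lda,
                    p->b, *p->b_scale_ptr, p->ldb,
                    p->bias_ptr, p->c, p->ldc,
                    std::nullopt /* alpha */);
    return OK;
  }
};

// Runs the op through the base Callable interface on a 2x2 example.
std::vector<float> RunToyExample() {
  const std::vector<float> a{1, 3, 2, 4};  // A = [[1, 2], [3, 4]], column-major
  const std::vector<float> b{1, 0, 0, 1};  // B = 2x2 identity
  const std::vector<float> bias{10, 20};
  const float a_scale = 0.5f, b_scale = 2.0f;
  std::vector<float> c(4, 0.0f);
  ToyScaledGemmParams params{2, 2, 2, a.data(), &a_scale, 2,
                             b.data(), &b_scale, 2, bias.data(), c.data(), 2};
  ToyDefaultScaledGemmOp op;
  Callable<ToyScaledGemmParams>& candidate = op;
  TuningStatus status = candidate.Call(&params);
  assert(status == OK);
  (void)status;
  return c;
}
```

RunToyExample() returns {11, 23, 12, 24} in column-major order. The scales 0.5 and 2.0 cancel and B is the identity, so the result is A with 10 added to row 0 and 20 added to row 1. As in the excerpt, the op has no tuning logic of its own. It delegates to the BLAS-style routine and returns OK, so a caller holding only a Callable reference can invoke it like any other candidate.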