RocblasGemmOp Class — PyTorch Architecture
Architecture documentation for the RocblasGemmOp class in GemmRocblas.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/cuda/tunable/GemmRocblas.h lines 137–172
// Tunable GEMM callable bound to one specific rocBLAS solution index.
// The tuning framework instantiates one of these per candidate solution
// and times each via Call().
template <typename T>
class RocblasGemmOp : public Callable<GemmParams<T>> {
 public:
  RocblasGemmOp(int solution) : solution_{solution} {}

  // Executes rocblas_gemm_ex with the stored solution index.
  // Returns FAIL if the configuration is unsupported (TF32 on fp32 data —
  // rocBLAS has no TF32 support) or if the rocBLAS call reports an error;
  // returns OK on success.
  TuningStatus Call(const GemmParams<T>* params) override {
    const auto io_dtype = RocBlasDataTypeFor<T>();
    const bool tf32_requested =
        at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
        at::Float32Precision::TF32;
    // Reject up front: rocBLAS offers no TF32 path for fp32 GEMM.
    if (tf32_requested && io_dtype == rocblas_datatype_f32_r) {
      return FAIL;
    }
    const auto comp_dtype = RocBlasComputeTypeFor<T>();
    // alpha/beta may need widening from half/bfloat16 to a type rocBLAS accepts.
    const auto alpha = DoCastForHalfOrBfloat16(params->alpha);
    const auto beta = DoCastForHalfOrBfloat16(params->beta);
    const auto status = rocblas_gemm_ex(
        (rocblas_handle)at::cuda::getCurrentCUDABlasHandle(),
        _rocblasOpFromChar(params->transa),
        _rocblasOpFromChar(params->transb),
        params->m, params->n, params->k,
        &alpha,
        params->a, io_dtype, params->lda,
        params->b, io_dtype, params->ldb,
        &beta,
        params->c, io_dtype, params->ldc,
        params->c, io_dtype, params->ldc,  // same pointer for C and D: in-place output
        comp_dtype,
        rocblas_gemm_algo_solution_index,
        solution_,
        rocblas_gemm_flags_none);
    return status == rocblas_status_success ? OK : FAIL;
  }

 private:
  int solution_;  // rocBLAS solution index exercised by this callable
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free