GetRocBlasGemmStridedBatchedTypeStringAndOps Function — pytorch Architecture
Architecture documentation for the GetRocBlasGemmStridedBatchedTypeStringAndOps function template in GemmRocblas.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/cuda/tunable/GemmRocblas.h lines 243–275
// Enumerates every rocBLAS GEMM solution applicable to strided-batched GEMM
// with input/output and compute types derived from T, and wraps each solution
// id in a tunable-op Callable.
//
// Returns: a vector of (name, op) pairs where the name is
// "Gemm_Rocblas_<solution_id>". Solution ids are sorted ascending so the
// returned vector is deterministic across runs (rocBLAS does not guarantee a
// stable enumeration order).
//
// Throws (via TORCH_ROCBLAS_CHECK) if either rocBLAS query fails.
template <typename T>
auto GetRocBlasGemmStridedBatchedTypeStringAndOps() {
  rocblas_handle handle = (rocblas_handle)at::cuda::getCurrentCUDABlasHandle();
  // Zero-init so the value is well-defined even if a caller inspects it
  // before the query populates it.
  int solution_size = 0;
  auto input_output_type = RocBlasDataTypeFor<T>();
  auto compute_type = RocBlasComputeTypeFor<T>();
  // First call with a null buffer: query only the number of solutions.
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            nullptr,
                                                            &solution_size));
  std::vector<int> solutions(solution_size);
  // Second call: fill the buffer with the actual solution ids.
  TORCH_ROCBLAS_CHECK(rocblas_gemm_ex_get_solutions_by_type(handle,
                                                            input_output_type,
                                                            input_output_type,
                                                            compute_type,
                                                            rocblas_gemm_flags_none,
                                                            solutions.data(),
                                                            &solution_size));
  // Sort the solutions in ascending order to make the solution vector
  // deterministic across runs.
  std::sort(solutions.begin(), solutions.end());
  std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmStridedBatchedParams<T>>>>> ret;
  // Final size is known; reserve once to avoid repeated reallocation.
  ret.reserve(solutions.size());
  for (int solution : solutions) {
    // Construct the pair in place — no temporary std::make_pair needed.
    ret.emplace_back(c10::str("Gemm_Rocblas_", solution),
                     std::make_unique<RocblasGemmStridedBatchedOp<T>>(solution));
  }
  return ret;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free