GemmStridedBatchedParams Class — PyTorch Architecture
Architecture documentation for the GemmStridedBatchedParams struct, defined in GemmCommon.h in the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/cuda/tunable/GemmCommon.h lines 480–586
// Parameter bundle for one strided-batched GEMM invocation, used by the
// TunableOp framework to identify a problem, duplicate its buffers, and
// numerically validate candidate solutions against a reference.
//
// T is the element type of the A/B input matrices; C_Dtype is the element
// type of the C/D output matrix and defaults to T. All sizing of the C
// buffer must therefore use sizeof(C_Dtype), not sizeof(T).
template <typename T, typename C_Dtype = T>
struct GemmStridedBatchedParams : OpParams {
  GemmStridedBatchedParams() = default;
  GemmStridedBatchedParams(const GemmStridedBatchedParams&) = default;
  GemmStridedBatchedParams(GemmStridedBatchedParams&&) noexcept = default;
  GemmStridedBatchedParams& operator=(const GemmStridedBatchedParams&) = default;
  GemmStridedBatchedParams& operator=(GemmStridedBatchedParams&&) noexcept = default;
  ~GemmStridedBatchedParams() override = default;

  // Full BLAS-style description of the problem for offline-tuning logs.
  // D aliases C in this API, so ldd/stride_d repeat ldc/stride_c and
  // d_type matches c_type (C_Dtype). Identical to the previous output
  // whenever C_Dtype == T (the default).
  std::string BLASSignature() const override {
    std::string alpha_str = to_string_opmath<T>(alpha);
    std::string beta_str = to_string_opmath<T>(beta);
    return fmt::sprintf("- { function: matmul, M: %ld, N: %ld, K: %ld, lda: %ld, ldb: %ld, ldc: %ld, ldd: %ld, stride_a: %ld, stride_b: %ld, stride_c: %ld, stride_d: %ld, "
        "alpha: %s, beta: %s, transA: %c, transB: %c, batch_count: %ld, a_type: %s, b_type: %s, c_type: %s, d_type: %s, scale_type: %s, compute_type: %s }",
        m, n, k, lda, ldb, ldc, ldc, stride_a, stride_b, stride_c, stride_c, alpha_str, beta_str, transa, transb, batch,
        BLASTypeName<T>(T{}), BLASTypeName<T>(T{}), BLASTypeName<C_Dtype>(C_Dtype{}), BLASTypeName<C_Dtype>(C_Dtype{}), ComputeTypeFor<T>(), ComputeTypeFor<T>());
  }

  // Compact key used to look up tuned solutions for this problem shape.
  // NOTE(review): deliberately unchanged (no strides / C_Dtype) to keep
  // previously recorded tuning results valid.
  std::string Signature() const override {
    return fmt::sprintf("%c%c_%ld_%ld_%ld_B_%ld_ld_%ld_%ld_%ld", transa, transb, m, n, k, batch, lda, ldb, ldc);
  }

  // Byte size of the A buffer. Strides may either pad (larger than the
  // dense extent) or overlap batches (smaller, e.g. a broadcast with
  // stride 0), so take the larger of the two extents.
  size_t GetSizeA() const {
    size_t size_stride = stride_a * batch;
    size_t size_dense = m * k * batch;
    return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
  }

  // Byte size of the B buffer; same stride-vs-dense reasoning as GetSizeA.
  size_t GetSizeB() const {
    size_t size_stride = stride_b * batch;
    size_t size_dense = k * n * batch;
    return sizeof(T) * (size_stride > size_dense ? size_stride : size_dense);
  }

  // Byte size of the C buffer. C holds C_Dtype elements, not T; using
  // sizeof(T) here under-allocated the DeepCopy buffer whenever
  // sizeof(C_Dtype) > sizeof(T) (e.g. low-precision inputs with a
  // wider output type).
  size_t GetSizeC() const {
    size_t size_stride = stride_c * batch;
    size_t size_dense = m * n * batch;
    return sizeof(C_Dtype) * (size_stride > size_dense ? size_stride : size_dense);
  }

  // Total scratch needed to evaluate one candidate: always a private C
  // buffer; plus private A/B copies when the op may read-modify inputs.
  size_t GetSize(bool duplicate_inputs) const {
    size_t size = GetSizeC();
    if (duplicate_inputs) {
      size += GetSizeA();
      size += GetSizeB();
    }
    return size;
  }

  // Returns a heap-allocated copy whose C (and optionally A/B) pointers
  // refer to freshly allocated device buffers; C's contents are copied so
  // beta-scaled accumulation starts from identical data. Caller must
  // release via Delete().
  GemmStridedBatchedParams* DeepCopy(bool duplicate_inputs) const {
    GemmStridedBatchedParams* copy = new GemmStridedBatchedParams(*this);
    c10::DeviceIndex device = 0;
    AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
    size_t c_size = GetSizeC();
    copy->c = static_cast<C_Dtype*>(c10::cuda::CUDACachingAllocator::raw_alloc(c_size));
    AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
        copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
    if (duplicate_inputs) {
      size_t a_size = GetSizeA();
      size_t b_size = GetSizeB();
      // A/B copies are scratch only; contents need not be initialized here.
      // NOLINTNEXTLINE(*const-cast*)
      copy->a = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(a_size));
      // NOLINTNEXTLINE(*const-cast*)
      copy->b = static_cast<const T*>(c10::cuda::CUDACachingAllocator::raw_alloc(b_size));
      copy->duplicate_inputs_ = true;
    }
    return copy;
  }

  // Frees the buffers allocated by DeepCopy. Only call on an object
  // returned by DeepCopy.
  void Delete() {
    c10::cuda::CUDACachingAllocator::raw_delete(c);
    if (duplicate_inputs_) {
      // NOLINTNEXTLINE(*const-cast*)
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(a));
      // NOLINTNEXTLINE(*const-cast*)
      c10::cuda::CUDACachingAllocator::raw_delete(const_cast<T*>(b));
    }
  }

  // Compares this result's C buffer against `other`'s within the
  // context-configured tolerance. `other` must be the same instantiation
  // (T, C_Dtype) — the previous <T>-only parameter could not type-check
  // against other->c when C_Dtype != T. Element count divides by
  // sizeof(C_Dtype), matching the buffer's actual element type.
  TuningStatus NumericalCheck(GemmStridedBatchedParams<T, C_Dtype>* other) {
    auto* ctx = getTuningContext();
    auto cfg = ctx->GetNumericalCheckConfig();
    auto c_dtype = c10::CppTypeToScalarType<C_Dtype>::value;
    return detail::NumericalCheck(c_dtype, c, other->c, GetSizeC() / sizeof(C_Dtype), cfg) ? OK : FAIL;
  }

  char transa{};
  char transb{};
  int64_t m{};
  int64_t n{};
  int64_t k{};
  at::opmath_type<T> alpha{};
  const T* a{};
  int64_t lda{};
  int64_t stride_a{};
  const T* b{};
  int64_t ldb{};
  int64_t stride_b{};
  // Brace-initialized like every other member; previously left
  // uninitialized, giving an indeterminate read in BLASSignature() on a
  // default-constructed object.
  at::opmath_type<T> beta{};
  C_Dtype* c{};
  int64_t ldc{};
  int64_t stride_c{};
  int64_t batch{};

 private:
  bool duplicate_inputs_{false};
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free