GmemTile Struct — pytorch Architecture
Architecture documentation for the GmemTile struct template (parameterized by FragmentType and kNumThreads) in kernel_backward.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h lines 67–163
template <typename FragmentType, int32_t kNumThreads>
struct GmemTile {
  /*
    Helper functions to efficiently store/load register-file (RF)
    fragments to/from global memory (gmem).

    GEMM accumulators have a particular format on A100, and it takes
    some compute/shared-memory to rearrange them to a RowMajor or
    ColumnMajor format in global memory through an Epilogue. The same
    complexity goes for loading into RF. This class loads/stores RF as
    they are, and can be used for efficient accumulation across gemms,
    for instance:
    ```
    GmemTile tile;
    for (int i = 0; i < N; ++i) {
      // ...
      Fragment accum;
      if (i == 0) {
        accum.clear();
      } else {
        tile.load(accum);
      }
      mma(accum, ...);
      if (i < N-1) {
        // Store for next GEMM
        tile.store(accum);
      } else {
        // Store in tensor (eg RowMajor)
        epilogue(accum);
      }
      // ...
    }
    ```
  */

  // 128 bits (4 floats) per thread per access.
  using AccessType = cutlass::Array<float, 4>;
  static constexpr int32_t kBytes = sizeof(AccessType);
  // Element distance (in floats) between two consecutive iterations of
  // the same thread: each iteration, all kNumThreads threads access one
  // AccessType each.
  static constexpr int32_t kStride = kNumThreads * AccessType::kElements;
  // Number of 128-bit accesses needed to cover one fragment.
  static constexpr int32_t kNumIters =
      FragmentType::kElements / AccessType::kElements;
  // Total number of floats written by the whole thread group.
  static constexpr int32_t kElementsStored =
      kNumThreads * FragmentType::kElements;
  static_assert(
      FragmentType::kElements % AccessType::kElements == 0,
      "fragment not aligned on 128 bits");

  // Base pointer of the tile in global memory; must cover at least
  // kElementsStored floats and be 128-bit aligned for vector access.
  float* ptr;

  // Loads this thread's slice of the tile from gmem into `fragment`.
  CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kNumIters; ++i) {
      AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
          ptr + thread_id * AccessType::kElements + i * kStride);
      AccessType sub_fragment;
      cutlass::arch::global_load<AccessType, kBytes>(
          sub_fragment, gmem_ptr, true);
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < AccessType::kElements; ++j) {
        fragment[i * AccessType::kElements + j] = sub_fragment[j];
      }
    }
  }

  // Stores `fragment` into this thread's slice of the tile in gmem.
  CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kNumIters; ++i) {
      AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
          ptr + thread_id * AccessType::kElements + i * kStride);
      AccessType sub_fragment;
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < AccessType::kElements; ++j) {
        sub_fragment[j] = fragment[i * AccessType::kElements + j];
      }
      cutlass::arch::global_store<AccessType, kBytes>(
          sub_fragment, gmem_ptr, true);
    }
  }

  // Accumulates `fragment` into gmem with per-element atomicAdd instead
  // of overwriting. Non-vectorized, so slower than `store`, but each
  // element update is atomic.
  CUTLASS_DEVICE void storeAtomicAdd(
      FragmentType const& fragment,
      int thread_id) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kNumIters; ++i) {
      float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride;
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < AccessType::kElements; ++j) {
        float val = fragment[i * AccessType::kElements + j];
        // `out_ptr` (not `ptr`): avoid shadowing the member `ptr`.
        float* out_ptr = gmem_ptr + j;
        atomicAdd(out_ptr, val);
      }
    }
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free