stage_pack_weights Function Template — PyTorch Architecture
Architecture documentation for the stage_pack_weights function template in Mm.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/vulkan/ops/Mm.h lines 16–54
// Copies `weight` into a staging buffer using a 2x2-interleaved layout and
// then hands that buffer to utils::pack_staging_to_vtensor() for `v_weight`.
//
// The source is read as `src_kb_sz` matrices of `src_kh_sz` x `src_kw_sz`
// elements of type T. Each destination matrix is made of four planes of
// `dst_kh_sz` x `dst_kw_sz` elements. The source element at (b, h, w) is
// written to plane 2 * (h % 2) + (w % 2) of matrix b, at position
// (h / 2) * dst_kw_sz + (w / 2) inside that plane. Destination slots that
// receive no source element stay zero because of the initial memset.
//
// NOTE(review): the staging buffer is always created with api::kFloat even
// though the element type T is a template parameter -- confirm this is
// intended for every T this template is instantiated with.
template <typename T>
void stage_pack_weights(
    api::Context* const context,
    vTensor& v_weight,
    const Tensor& weight,
    const int64_t src_kb_sz,
    const int64_t src_kh_sz,
    const int64_t src_kw_sz,
    const int64_t dst_kh_sz,
    const int64_t dst_kw_sz) {
  // Element counts: one source matrix, one destination plane, and one
  // destination matrix (four planes).
  const int64_t src_elems_per_matrix = src_kw_sz * src_kh_sz;
  const int64_t dst_elems_per_plane = dst_kw_sz * dst_kh_sz;
  const int64_t dst_elems_per_matrix = dst_elems_per_plane * 4;

  const T* const src_base = weight.const_data_ptr<T>();

  api::StorageBuffer staging(context, api::kFloat, v_weight.gpu_numel());

  // The mapping lives only inside this scope, so it is destroyed before the
  // staging buffer is passed on below.
  {
    api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);
    T* const dst_base = mapping.template data<T>();

    // Zero-fill first; the scatter below only writes slots that have a
    // corresponding source element.
    memset(dst_base, 0, v_weight.nbytes());

    for (const auto batch : c10::irange(src_kb_sz)) {
      const int64_t src_batch_offset = batch * src_elems_per_matrix;
      const int64_t dst_batch_offset = batch * dst_elems_per_matrix;

      for (const auto row : c10::irange(src_kh_sz)) {
        const int64_t src_row_offset = src_batch_offset + row * src_kw_sz;
        const int64_t row_parity = row % 2;
        const int64_t dst_row_offset = (row / 2) * dst_kw_sz;

        for (const auto col : c10::irange(src_kw_sz)) {
          // Row/column parity selects one of the four destination planes.
          const int64_t plane = 2 * row_parity + (col % 2);
          const int64_t slot = dst_row_offset + (col / 2);

          const T* const src_elem = src_base + (src_row_offset + col);
          T* const dst_elem =
              dst_base + (dst_batch_offset + plane * dst_elems_per_plane + slot);

          memcpy(dst_elem, src_elem, sizeof(T));
        }
      }
    }
  }

  utils::pack_staging_to_vtensor(staging.buffer(), v_weight);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free