multi_tensor_apply_for_fused_optimizer (template parameter: depth) — PyTorch Architecture
Architecture documentation for the multi_tensor_apply_for_fused_optimizer function template (parameterized by depth, the number of parallel tensor lists) in MultiTensorApply.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/mps/operations/MultiTensorApply.h lines 120–254
// Runs a fused-optimizer Metal kernel (looked up via getFusedAdamCPLState) over many
// tensors using as few GPU dispatches as possible.
//
// `tensor_lists` holds `depth` parallel lists (e.g. params / grads / optimizer state);
// tensors at the same index across the lists belong to the same parameter. Tensor
// buffers are packed into a Metal argument buffer, each tensor is split into chunks of
// kChunkSize elements, and one threadgroup processes one chunk. A dispatch is flushed
// whenever the argument buffer's tensor slots (kmaxTensors per list) are exhausted or
// the per-dispatch threadgroup budget (kmaxThreadGroups) is reached.
//
// Template parameters:
//   depth            - number of parallel tensor lists expected in `tensor_lists`.
//   kThreadGroupSize - requested threads per threadgroup (clamped below to the
//                      pipeline state's maxTotalThreadsPerThreadgroup).
// Parameters:
//   kernel_name  - name of the Metal compute kernel / pipeline state to run.
//   tensor_lists - `depth` lists of equal length; list 0 determines numel per tensor.
//   state_steps  - optional per-tensor step counters (may be empty); bound read-only
//                  into the slots after the `depth` tensor sections.
//   encode       - callback invoked before each dispatch to bind the argument buffer,
//                  the chunk metadata, and `args...` onto the compute encoder.
template <int depth, uint32_t kThreadGroupSize, typename encoder_func_t, typename... ArgTypes>
static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_name,
std::vector<std::vector<at::Tensor>>& tensor_lists,
at::TensorList state_steps,
encoder_func_t encode,
ArgTypes... args) {
const auto num_tensors = tensor_lists[0].size();
// Nothing to encode for an empty tensor list.
if (num_tensors == 0) {
return;
}
TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
// dtype gate: only the first tensor of each list is inspected, so this assumes each
// list is dtype-homogeneous — presumably enforced by callers; TODO confirm.
for (const auto& d : c10::irange(depth)) {
const auto scalar_type = tensor_lists[d][0].scalar_type();
TORCH_CHECK(scalar_type == kFloat || scalar_type == kHalf || scalar_type == kBFloat16,
"Only float, bfloat and half are supported");
}
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
// Uncomment the handler below to dump command-buffer logs when debugging.
/*
mpsStream->addCompletedHandler(^(id<MTLCommandBuffer> cb) {
[cb.logs enumerateObjectsUsingBlock:^(NSString* log, NSUInteger idx, BOOL* stop) {
NSLog(@"MPSStream: %@", log);
}
];
});
*/
// All encoding happens synchronously on the stream's queue; exceptions are rethrown
// to the caller by dispatch_sync_with_rethrow.
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto [fusedOptimizerPSO, fusedOptimizerFunc] = getFusedAdamCPLState(kernel_name);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(fusedOptimizerPSO, kernel_name, {tensor_lists[0]});
[computeEncoder setComputePipelineState:fusedOptimizerPSO];
// BufferIndex is the index in the kernel function. Slot layout inside the argument
// buffer: list d's tensor at packed position i lives in slot d * kmaxTensors + i;
// optional state_steps occupy slots depth * kmaxTensors + i.
auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease];
id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
// tensor_loc: next free packed-tensor slot in the current argument buffer.
// threadgroup_loc: threadgroups accumulated toward the current dispatch.
int64_t tensor_loc = 0;
int64_t threadgroup_loc = 0;
// Maps each threadgroup of the pending dispatch to (packed tensor index, chunk index).
MetadataArguments metadata_arguments;
for (const auto tensor_index : c10::irange(num_tensors)) {
// short-circuit to avoid adding empty tensors to tensorListMeta
if (tensor_lists[0][tensor_index].numel() == 0) {
continue;
}
// Bind this tensor's buffer (for every list) into the current packed slot and mark
// it resident for the GPU pass.
for (const auto& d : c10::irange(depth)) {
mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc);
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
usage:MTLResourceUsageRead | MTLResourceUsageWrite];
}
if (!state_steps.empty()) {
mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc);
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
}
metadata_arguments.numels[tensor_loc] = tensor_lists[0][tensor_index].numel();
tensor_loc++;
// Ceil-divide numel into kChunkSize-element chunks; one threadgroup per chunk.
const auto numel = tensor_lists[0][tensor_index].numel();
const auto chunks = numel / kChunkSize + (numel % kChunkSize != 0);
// Sanity guard only: chunks is non-negative by construction for numel >= 0.
TORCH_CHECK(chunks > -1);
for (const auto& chunk : c10::irange(chunks)) {
// Record the (tensor, chunk) pair this threadgroup will process.
metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = tensor_loc - 1;
metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk;
threadgroup_loc++;
// Flush when the argument buffer is full AND this tensor's chunks are complete...
const auto tensor_full = tensor_loc == kmaxTensors && chunk == chunks - 1;
// Reach the maximum threadgroups per dispatch
const auto blocks_full = threadgroup_loc == kmaxThreadGroups;
if (tensor_full || blocks_full) {
// Emit one dispatch covering the threadgroups accumulated so far.
encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...);
MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
[computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
// Reset
threadgroup_loc = 0;
if (chunk == chunks - 1) {
// last chunk: start the next dispatch with a completely fresh argument buffer
// (a new buffer is required because the previous one is still referenced by
// the just-encoded dispatch).
tensor_loc = 0;
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
} else {
// reuse the current tensor since the current one isn't done: re-bind it at
// packed index 0 of a fresh argument buffer so its remaining chunks can
// still reference it in the next dispatch.
metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
for (const auto& d : c10::irange(depth)) {
mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors);
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
usage:MTLResourceUsageWrite | MTLResourceUsageRead];
}
if (!state_steps.empty()) {
mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
}
// Slot 0 is now occupied by the in-progress tensor.
tensor_loc = 1;
}
}
}
}
// Flush any remaining threadgroups that did not trigger a full/blocks flush above.
if (threadgroup_loc != 0) {
encode(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...);
MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1);
uint32_t maxThreadsPerGroup = [fusedOptimizerPSO maxTotalThreadsPerThreadgroup];
MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreadsPerGroup, kThreadGroupSize), 1, 1);
[computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize];
}
getMPSProfiler().endProfileKernel(fusedOptimizerPSO);
}
});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free