MetalKernelFunction Class — PyTorch Architecture
Architecture documentation for the MetalKernelFunction class in MetalShaderLibrary.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/mps/MetalShaderLibrary.h lines 53–98
// RAII-style wrapper around a compiled Metal compute kernel: a pipeline state
// plus its function handle, with helpers to encode arguments and dispatch.
// NOTE(review): the non-inline members are defined elsewhere (this is a header
// declaration); comments on them describe the declared contract only.
class MetalKernelFunction {
 public:
  MetalKernelFunction(MTLComputePipelineState_t cps_, MTLFunction_t f_);
  ~MetalKernelFunction();
  // Non-copyable: this type owns raw Metal handles, so a shallow copy would
  // alias cps/func/encoder. Rule of five: with a user-declared destructor,
  // both copy operations are deleted explicitly (the original deleted only
  // the non-const copy constructor, leaving copy assignment implicitly
  // generated); move operations are not generated either.
  MetalKernelFunction(const MetalKernelFunction&) = delete;
  MetalKernelFunction& operator=(const MetalKernelFunction&) = delete;
  // Shader properties (queried from the pipeline state).
  uint64_t getMaxThreadsPerThreadgroup() const;
  uint64_t getThreadExecutionWidth() const;
  uint64_t getStaticThreadGroupMemoryLength() const;
  // Runs `f` inside a command-encoding scope; the encoding/dispatch methods
  // below are only valid when called from within `f`.
  void runCommandBlock(std::function<void(void)> f);
  // Methods below should be called from runCommandBlock function
  void startEncoding();
  void setArg(unsigned idx, const at::TensorBase& t);
  // Raw-bytes overload: copies `size` bytes from `ptr` into argument slot `idx`.
  void setArg(unsigned idx, const void* ptr, uint64_t size);
  void setErrorBufferIndex(unsigned idx);
  // Scalar / trivially-copyable POD overload. Excludes container-like types
  // (those with a size_type) so they route to the Container overload below.
  template <
      typename T,
      typename = std::enable_if_t<
          std::is_integral_v<T> || std::is_same_v<T, float> ||
          (std::is_class_v<T> && std::is_trivially_copyable_v<T> &&
           !detail::has_size_type_v<T>)>>
  inline void setArg(unsigned idx, const T val) {
    setArg(idx, &val, sizeof(T));
  }
  // Contiguous-container overload: uploads values.data() as a flat byte blob.
  // Assumes Container stores its elements contiguously (data()/size()).
  template <
      typename Container,
      typename = std::enable_if_t<detail::has_size_type_v<Container>>>
  inline void setArg(unsigned idx, const Container& values) {
    setArg(
        idx,
        values.data(),
        values.size() * sizeof(typename Container::value_type));
  }
  // 1-D dispatch: total `length` threads; groupSize defaults to an
  // implementation-chosen value when not provided.
  void dispatch(
      uint64_t length,
      std::optional<uint64_t> groupSize = std::nullopt);
  // N-D dispatch over a grid described by `length` per dimension.
  void dispatch(
      c10::ArrayRef<uint64_t> length,
      c10::OptionalArrayRef<uint64_t> groupSize = std::nullopt);

 private:
  MTLComputePipelineState_t cps; // owned pipeline state handle
  MTLFunction_t func; // owned function handle
  MTLComputeCommandEncoder_t encoder = nullptr; // live only during encoding
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free