DispatchStub Class — pytorch Architecture
Architecture documentation for the DispatchStub class in DispatchStub.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/DispatchStub.h lines 216–330
// Partial specialization of DispatchStub for kernels whose type is a plain
// function pointer `rT (*)(Args...)`.
//
// The stub holds one static kernel pointer per CPU capability that is compiled
// in (DEFAULT plus whichever HAVE_*_CPU_DEFINITION macros are defined) and
// forwards all lookups to the type-erased `impl`, which works on `void*`.
// Kernel pointers are therefore cast to `void*` on the way in and back to
// `FnPtr` on the way out.
template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T> {
  using FnPtr = rT (*)(Args...);

  DispatchStub() = default;

  // A stub is not copyable.
  DispatchStub(const DispatchStub&) = delete;
  DispatchStub& operator=(const DispatchStub&) = delete;

 private:
  // Asks `impl` for the kernel to run on `device_type`, handing it DEFAULT and
  // every CPU-capability variant available in this build, and returns the
  // answer as a typed function pointer.
  FnPtr get_call_ptr(const c10::DeviceType device_type) {
    return reinterpret_cast<FnPtr>(
        impl.get_call_ptr(
            device_type
            , reinterpret_cast<void*>(DEFAULT)
#if defined(HAVE_AVX512_CPU_DEFINITION)
            , reinterpret_cast<void*>(AVX512)
#endif
#if defined(HAVE_AVX2_CPU_DEFINITION)
            , reinterpret_cast<void*>(AVX2)
#endif
#if defined(HAVE_VSX_CPU_DEFINITION)
            , reinterpret_cast<void*>(VSX)
#endif
#if defined(HAVE_ZVECTOR_CPU_DEFINITION)
            , reinterpret_cast<void*>(ZVECTOR)
#endif
#if defined(HAVE_SVE256_CPU_DEFINITION)
            , reinterpret_cast<void*>(SVE256)
#endif
        )
    );
  }

 public:
  // Looks up the kernel for `device_type` and invokes it with `args`
  // perfectly forwarded; the kernel's result is returned unchanged.
  template <typename... ArgTypes>
  rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
    // Resolve the kernel first, then perform the call.
    const FnPtr kernel = get_call_ptr(device_type);
    return kernel(std::forward<ArgTypes>(args)...);
  }

  // Per-backend setters: each one stores `fn_ptr`, type-erased to `void*`, in
  // the matching slot of `impl`.

  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  // The XPU setter only exists when USE_XPU is defined.
#if defined(USE_XPU)
  void set_xpu_dispatch_ptr(FnPtr fn_ptr) {
    impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
#endif

  void set_hpu_dispatch_ptr(FnPtr fn_ptr) {
    impl.hpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
    impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_mtia_dispatch_ptr(FnPtr fn_ptr) {
    impl.mtia_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
    impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  // Returns true if the dispatcher has a kernel registered for this device
  // type, i.e. the non-throwing lookup on `impl` does not come back holding
  // an ErrorType.
  bool is_device_supported(const c10::DeviceType device_type) {
    const auto lookup = impl.try_get_call_ptr(
        device_type
        , reinterpret_cast<void*>(DEFAULT)
#if defined(HAVE_AVX512_CPU_DEFINITION)
        , reinterpret_cast<void*>(AVX512)
#endif
#if defined(HAVE_AVX2_CPU_DEFINITION)
        , reinterpret_cast<void*>(AVX2)
#endif
#if defined(HAVE_VSX_CPU_DEFINITION)
        , reinterpret_cast<void*>(VSX)
#endif
#if defined(HAVE_ZVECTOR_CPU_DEFINITION)
        , reinterpret_cast<void*>(ZVECTOR)
#endif
#if defined(HAVE_SVE256_CPU_DEFINITION)
        , reinterpret_cast<void*>(SVE256)
#endif
    );
    return !std::holds_alternative<ErrorType>(lookup);
  }

  // One static kernel pointer per CPU capability. DEFAULT is always declared;
  // the others are declared only when the corresponding macro is defined.
  static TORCH_API FnPtr DEFAULT;
#if defined(HAVE_AVX512_CPU_DEFINITION)
  static TORCH_API FnPtr AVX512;
#endif
#if defined(HAVE_AVX2_CPU_DEFINITION)
  static TORCH_API FnPtr AVX2;
#endif
#if defined(HAVE_VSX_CPU_DEFINITION)
  static TORCH_API FnPtr VSX;
#endif
#if defined(HAVE_ZVECTOR_CPU_DEFINITION)
  static TORCH_API FnPtr ZVECTOR;
#endif
#if defined(HAVE_SVE256_CPU_DEFINITION)
  static TORCH_API FnPtr SVE256;
#endif

 private:
  // Type-erased state: holds the per-backend `void*` slots written by the
  // setters above and performs the actual lookup.
  DispatchStubImpl impl;
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free