primitive_ext Class — PyTorch Architecture
Architecture documentation for the `primitive_ext` class in `DnnlExt.h` from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/mkldnn/xpu/detail/DnnlExt.h lines 31–303
class primitive_ext : public primitive {
static constexpr int max_args = 12;
public:
primitive_ext(const primitive& base) : primitive(base) {}
primitive_ext(primitive&& base) : primitive(std::move(base)) {}
/// Returns a memory descriptor.
///
/// @note
/// There are also convenience methods
/// #dnnl::primitive_desc_base::src_desc(),
/// #dnnl::primitive_desc_base::dst_desc(), and others.
///
/// @param what The kind of parameter to query; can be
/// #dnnl::query::src_md, #dnnl::query::dst_md, etc.
/// @param idx Index of the parameter. For example, convolution bias can
/// be queried with what = #dnnl::query::weights_md and idx = 1.
/// @returns The requested memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// parameter of the specified kind or index.
const_dnnl_memory_desc_t query_md(query what, int idx = 0) const {
std::vector<query> valid_q{
query::src_md,
query::diff_src_md,
query::weights_md,
query::diff_weights_md,
query::dst_md,
query::diff_dst_md,
query::workspace_md,
query::scratchpad_md,
query::exec_arg_md};
if (!std::any_of(valid_q.cbegin(), valid_q.cend(), [=](query q) {
return what == q;
}))
DNNL_THROW_ERROR(
dnnl_invalid_arguments, "memory descriptor query is invalid");
const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md(
this->get_primitive_desc(), dnnl::convert_to_c(what), idx);
return cdesc ? cdesc : nullptr;
}
/// Returns a source memory descriptor.
/// @param idx Source index.
/// @returns Source memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// source parameter with index @p idx.
const_dnnl_memory_desc_t src_desc(int idx) const {
return query_md(query::src_md, idx);
}
/// Returns a destination memory descriptor.
/// @param idx Destination index.
/// @returns Destination memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// destination parameter with index @p idx.
const_dnnl_memory_desc_t dst_desc(int idx) const {
return query_md(query::dst_md, idx);
}
/// Returns a weights memory descriptor.
/// @param idx Weights index.
/// @returns Weights memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// weights parameter with index @p idx.
const_dnnl_memory_desc_t weights_desc(int idx) const {
return query_md(query::weights_md, idx);
}
/// Returns a diff source memory descriptor.
/// @param idx Diff source index.
/// @returns Diff source memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff source parameter with index @p idx.
const_dnnl_memory_desc_t diff_src_desc(int idx) const {
return query_md(query::diff_src_md, idx);
}
/// Returns a diff destination memory descriptor.
/// @param idx Diff destination index.
/// @returns Diff destination memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff destination parameter with index @p idx.
const_dnnl_memory_desc_t diff_dst_desc(int idx) const {
return query_md(query::diff_dst_md, idx);
}
/// Returns a diff weights memory descriptor.
/// @param idx Diff weights index.
/// @returns Diff weights memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff weights parameter with index @p idx.
const_dnnl_memory_desc_t diff_weights_desc(int idx) const {
return query_md(query::diff_weights_md, idx);
}
const_dnnl_memory_desc_t exec_arg_desc(int idx) const {
return query_md(query::exec_arg_md, idx);
}
// Separate versions without the index argument for documentation
// purposes.
/// Returns a source memory descriptor.
/// @returns Source memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// source parameter.
const_dnnl_memory_desc_t src_desc() const {
return src_desc(0);
}
/// Returns a destination memory descriptor.
/// @returns Destination memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// destination parameter.
const_dnnl_memory_desc_t dst_desc() const {
return dst_desc(0);
}
/// Returns a weights memory descriptor.
/// @returns Weights memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// weights parameter.
const_dnnl_memory_desc_t weights_desc() const {
return weights_desc(0);
}
/// Returns a diff source memory descriptor.
/// @returns Diff source memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff source memory with.
const_dnnl_memory_desc_t diff_src_desc() const {
return diff_src_desc(0);
}
/// Returns a diff destination memory descriptor.
/// @returns Diff destination memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff destination parameter.
const_dnnl_memory_desc_t diff_dst_desc() const {
return diff_dst_desc(0);
}
/// Returns a diff weights memory descriptor.
/// @returns Diff weights memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff weights parameter.
const_dnnl_memory_desc_t diff_weights_desc() const {
return diff_weights_desc(0);
}
/// Returns the workspace memory descriptor.
/// @returns Workspace memory descriptor.
/// @returns A zero memory descriptor if the primitive does not require
/// workspace parameter.
const_dnnl_memory_desc_t workspace_desc() const {
return query_md(query::workspace_md, 0);
}
/// Returns the scratchpad memory descriptor.
/// @returns scratchpad memory descriptor.
/// @returns A zero memory descriptor if the primitive does not require
/// scratchpad parameter.
/// @sa @ref dev_guide_attributes_scratchpad
const_dnnl_memory_desc_t scratchpad_desc() const {
return query_md(query::scratchpad_md, 0);
}
inline memory make_memory(
const_dnnl_memory_desc_t md_t,
const engine& aengine,
void* handle = DNNL_MEMORY_ALLOCATE) const {
sycl_interop::memory_kind kind = dnnl::sycl_interop::memory_kind::usm;
dnnl_memory_t c_memory;
error::wrap_c_api(
dnnl_sycl_interop_memory_create(
&c_memory, md_t, aengine.get(), convert_to_c(kind), handle),
"could not create a memory");
return memory(c_memory);
}
memory make_src(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE)
const {
return make_memory(src_desc(), aengine, handle);
}
memory make_weight(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE)
const {
return make_memory(weights_desc(), aengine, handle);
}
memory make_bias(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE)
const {
return make_memory(weights_desc(1), aengine, handle);
}
memory make_dst(const engine& aengine, void* handle = DNNL_MEMORY_ALLOCATE)
const {
return make_memory(dst_desc(), aengine, handle);
}
memory make_scratchpad(
const engine& aengine,
void* handle = DNNL_MEMORY_ALLOCATE) const {
return make_memory(scratchpad_desc(), aengine, handle);
}
size_t get_scratchpad_size() const {
return dnnl_memory_desc_get_size(scratchpad_desc());
}
memory make_args(int arg_class, const engine& aengine, void* handle) const {
switch (arg_class) {
case DNNL_ARG_SRC:
return make_src(aengine, handle);
case DNNL_ARG_WEIGHTS:
return make_weight(aengine, handle);
case DNNL_ARG_SCRATCHPAD:
return make_scratchpad(aengine, handle);
case DNNL_ARG_DST:
return make_dst(aengine, handle);
case DNNL_ARG_BIAS:
return make_bias(aengine, handle);
default:
TORCH_INTERNAL_ASSERT(
false, "unsupported argument class for primitive_ext");
}
}
template <typename M>
void set_attribute(int slot, int arg_class, void* handle, M constructor) {
if (mem_arg_cache[slot])
mem_arg_cache[slot].set_data_handle(handle);
else {
mem_arg_cache[slot] = constructor();
c_args[slot].arg = arg_class;
c_args[slot].memory = mem_arg_cache[slot].get();
}
}
sycl::event execute(
const stream& astream,
const engine& aengine,
std::vector<std::pair<int, void*>>&& handles,
int slot_off = 2) {
auto off = slot_off;
for (const auto& p : handles) {
auto& m_arg = mem_arg_cache[off];
if (m_arg)
m_arg.set_data_handle(p.second);
else {
m_arg = make_args(p.first, aengine, p.second);
c_args[off].arg = p.first;
c_args[off].memory = m_arg.get();
}
++off;
}
sycl::event return_event;
std::vector<sycl::event> deps{};
error::wrap_c_api(
dnnl_sycl_interop_primitive_execute(
this->get(), astream.get(), off, c_args, &deps, &return_event),
"could not execute a primitive");
return return_event;
}
private:
memory mem_arg_cache[max_args];
dnnl_exec_arg_t c_args[max_args];
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free