PackedConvWeightsOnednn Struct (template parameter kSpatialDim) — PyTorch Architecture
Architecture documentation for the PackedConvWeightsOnednn struct template, parameterized by kSpatialDim, in OnednnUtils.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/OnednnUtils.h lines 190–310
// Packed convolution parameters for the oneDNN (ideep) backend.
//
// Holds the packed ideep weight/bias, the original ATen tensors they were
// built from, and the convolution hyper-parameters exposed through the
// ConvPackedParamsBase<kSpatialDim> accessors. A single object serves both
// regular and transposed convolutions; `transpose_` selects which of the two
// private primitive caches is valid.
template <int kSpatialDim = 2>
struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
  // Takes ownership of every argument (all are moved into members).
  // `transpose` is treated as a boolean: nonzero means transposed conv.
  PackedConvWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      std::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      std::optional<at::Tensor> orig_bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      uint8_t transpose)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        transpose_(transpose),
        // Listed last to match declaration order (the caches between
        // transpose_ and this flag are default-initialized).
        cache_initialized_flag(std::make_unique<c10::once_flag>()) {}

  // Packed ideep weight (owned).
  std::unique_ptr<ideep::tensor> weight_;
  // Packed ideep bias; empty when the layer has no bias.
  std::optional<ideep::tensor> bias_;
  // Original ATen tensors. NOTE(review): presumably retained so unpack() can
  // return them; confirm against the out-of-line definition.
  at::Tensor orig_weight_;
  std::optional<at::Tensor> orig_bias_;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  // Stored as uint8_t; read as a bool through transpose().
  uint8_t transpose_;

  // Entry points below are only declared here; their definitions are not
  // visible in this chunk. Names suggest: plain conv, conv with fused ReLU,
  // dynamically-quantized conv, and conv fused with an add of `accum`
  // (optionally followed by ReLU) — confirm against the implementations.
  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) override;

  at::Tensor apply_add(
      const at::Tensor& input,
      const at::Tensor& accum,
      double output_scale,
      int64_t output_zero_point);

  at::Tensor apply_add_relu(
      const at::Tensor& input,
      const at::Tensor& accum,
      double output_scale,
      int64_t output_zero_point);

  // Returns the (weight, optional bias) pair; definition not visible here.
  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  // Factory that builds a packed-params object from unpacked tensors and
  // conv hyper-parameters; definition not visible here.
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  // Hyper-parameter accessors required by ConvPackedParamsBase; each
  // returns a copy of the stored value.
  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  // True iff this object describes a transposed convolution.
  bool transpose() const override {
    return static_cast<bool>(transpose_);
  }

 private:
  // Exactly one of these caches is meaningful for a given object, chosen by
  // transpose(); see get_conv_cache() / get_deconv_cache().
  ConvPrimitiveCache conv_prim_cache;
  DeconvPrimitiveCache deconv_prim_cache;
  // NOTE(review): heap-allocated, presumably because c10::once_flag is not
  // copyable/movable and holding it by value would pin this struct in place;
  // its consumers are not visible in this chunk — confirm.
  std::unique_ptr<c10::once_flag> cache_initialized_flag;

  // Shared implementation template; `ReluFused` selects ReLU fusion and
  // `accum` is the optional add operand. Definition not visible here.
  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      const std::optional<at::Tensor>& accum,
      double output_scale,
      int64_t output_zero_point);

  // Cache accessors. The precondition checks use assert(), so they are
  // compiled out when NDEBUG is defined.
  ConvPrimitiveCache& get_conv_cache() {
    assert(!transpose());
    return conv_prim_cache;
  }

  DeconvPrimitiveCache& get_deconv_cache() {
    assert(transpose());
    return deconv_prim_cache;
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free