post_op Template Parameter — pytorch Architecture
Architecture documentation for the post_op template parameter of PackedLinearWeightsOnednn::apply_impl in qlinear.cpp from the pytorch codebase. post_op is a compile-time PostOps enum value that selects which activation, if any, is fused into the quantized linear primitive.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/qlinear.cpp lines 800–893
template <PostOps post_op>
at::Tensor PackedLinearWeightsOnednn::apply_impl(
    at::Tensor input,
    double output_scale,
    int64_t output_zero_point,
    torch::List<at::Scalar> post_op_args) {
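  // Validate rank and quantized dtype, then flatten all leading dims so the
  // linear op runs as a single (M x K) * (K x N) matmul.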
  const int64_t dim = input.dim();
  TORCH_CHECK(
      dim != 0,
      "qlinear (ONEDNN): input dim should be at least 1, but got 0");
  TORCH_CHECK(input.scalar_type() == c10::ScalarType::QUInt8 || input.scalar_type() == c10::ScalarType::QInt8,
      "qlinear (ONEDNN): data type of input should be QUInt8 or QInt8.");
  auto is_input_qint8 = input.scalar_type() == c10::ScalarType::QInt8;
  auto input_contig = input.expect_contiguous();
  auto& w = *weight_;
  auto K = input.size(dim - 1), M = input.numel() / K, N = w.get_dim(1);
  auto input_dims = {M, K};
  auto input_data_type = is_input_qint8 ? dnnl::memory::data_type::s8 : dnnl::memory::data_type::u8;
  auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
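  // Map the compile-time post_op value to a oneDNN attribute so the
  // activation is fused into the matmul primitive instead of running
  // as a separate op.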
  ideep::attr_t op_attr = ideep::attr_t();
  if (post_op == Relu) {
    op_attr = ideep::attr_t::fuse_relu();
  } else if (post_op == LeakyRelu) {
    op_attr = ideep::attr_t::fuse_relu(/*scale=*/1.0f, /*alpha=*/post_op_args.get(0).to<double>());
  } else if (post_op == Tanh) {
    op_attr = ideep::attr_t::fuse_tanh();
  }
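  // Wrap the contiguous input buffer as an ideep tensor (no copy) and build
  // per-tensor scales and zero points in oneDNN's convention.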
  ideep::tensor x(input_desc, input_contig->data_ptr());
  auto dst_dims = {M, N};
  double input_scale = input.q_scale();
  int64_t input_zero_point = input.q_zero_point();
  const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/input_scale);
  const ideep::scale_t& weights_scales = w.get_scale();
  // Scales of ONEDNN and PyTorch are reciprocal
  const ideep::scale_t& dst_scales = ideep::scale_t(1, 1.0/output_scale);
  const ideep::zero_point_t& src_zero_point = ideep::zero_point_t(1, input_zero_point);
  const ideep::zero_point_t& dst_zero_point = ideep::zero_point_t(1, output_zero_point);
  // Compute: Use ideep::matmul_forward to support asymmetric quantization
  // Allocate output Tensor
  at::Tensor output = at::_empty_affine_quantized(
      dst_dims,
      at::device(c10::kCPU).dtype(is_input_qint8 ? c10::kQInt8 : c10::kQUInt8),
      output_scale,
      output_zero_point);
  if (output.numel() == 0) {
    return output;
  }
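  // Describe the output in ideep terms and view the preallocated buffer as
  // an ideep tensor so oneDNN writes its result directly into it.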
  auto output_ideep_data_type = is_input_qint8 ? ideep::tensor::data_type::s8 : ideep::tensor::data_type::u8;
  auto ideep_lowp_kind = is_input_qint8 ? ideep::s8s8 : ideep::u8s8;
  ideep::tensor y({dst_dims, output_ideep_data_type,
                   {output.strides().cbegin(), output.strides().cend()}},
                  output.data_ptr());
  bool with_bias = bias_.has_value();
  if (with_bias) {
    // Bias might be modified outside (e.g. by quantization bias correction).
    // If so, update the prepacked bias as well.
    if (bias_.value().get_data_handle() != orig_bias_.value().data_ptr()) {
      bias_.value().init(bias_.value().get_desc(), orig_bias_.value().data_ptr());
    }
  }
  const auto& b = with_bias ? bias_.value() : ideep::tensor();
  // Primitive cache is initialized when called for the first time
  // and won't be updated afterwards.
  int num_threads = at::get_num_threads();
  PrimitiveCacheKey cache_key = std::make_tuple(
      input_scale, input_zero_point, input_dims, output_scale, output_zero_point, num_threads, /*accum scale*/1.0, /*accum zero point*/0);
  c10::call_once(*cache_initialized_flag, [&](){
    LinearParams params;
    ideep::matmul_forward::prepare</*is_dynamic=*/false>(
        params, x, w, b, y,
        src_scales, weights_scales, dst_scales,
        src_zero_point, dst_zero_point, 1.0f, 1.0f, op_attr,
        output_ideep_data_type,
        ideep_lowp_kind);
    get_cache() = LinearPrimitiveCache(cache_key, params);
    w = w.reorder_if_differ_in(params.pd.weights_desc());
  });
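  // Reuse the cached primitive when the key still matches; otherwise fall
  // back to the uncached compute path (e.g. after a change in input shape,
  // scales, or thread count).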
  if (get_cache().hit(cache_key)) {
    LinearParams& params = get_cache().get_param();
    ideep::matmul_forward::compute<false, false>(params, x, w, b, y);
  } else {
    ideep::matmul_forward::compute(x, w, b, y, src_scales, weights_scales,
                                   dst_scales, src_zero_point, dst_zero_point,
                                   1.0f, 1.0f, op_attr,
                                   output_ideep_data_type,
                                   ideep_lowp_kind);
  }
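  // Restore the original leading dimensions; reshape only if flattening
  // actually changed the shape.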
  auto out_sizes = input.sizes().vec();
  out_sizes.back() = N;
  if (output.sizes().vec() == out_sizes)
    return output;
  return output.reshape(out_sizes);
}
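The post_op value is bound at the public entry points of PackedLinearWeightsOnednn, each of which instantiates apply_impl with a different PostOps enumerator. The sketch below illustrates that dispatch pattern, modeled on the wrapper style used in qlinear.cpp; exact signatures and default arguments in a given PyTorch version may differ, and the empty post_op_args lists are passed explicitly here so the sketch matches the signature shown above.

at::Tensor PackedLinearWeightsOnednn::apply(
    at::Tensor input,
    double output_scale,
    int64_t output_zero_point) {
  // NoPostOp: plain quantized linear with no fused activation.
  return apply_impl<NoPostOp>(
      std::move(input), output_scale, output_zero_point,
      torch::List<at::Scalar>());
}

at::Tensor PackedLinearWeightsOnednn::apply_relu(
    at::Tensor input,
    double output_scale,
    int64_t output_zero_point) {
  // Relu: apply_impl selects ideep::attr_t::fuse_relu().
  return apply_impl<Relu>(
      std::move(input), output_scale, output_zero_point,
      torch::List<at::Scalar>());
}

at::Tensor PackedLinearWeightsOnednn::apply_leaky_relu(
    at::Tensor input,
    double output_scale,
    int64_t output_zero_point,
    double negative_slope) {
  // LeakyRelu consumes one post_op_arg: the negative slope (alpha).
  torch::List<at::Scalar> post_op_args = {at::Scalar(negative_slope)};
  return apply_impl<LeakyRelu>(
      std::move(input), output_scale, output_zero_point,
      std::move(post_op_args));
}

Because post_op is a template parameter rather than a runtime argument, each instantiation compiles down to a single fused oneDNN configuration with the unused branches eliminated.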