Conv2dOpContext Class — PyTorch Architecture
Architecture documentation for the Conv2dOpContext class in MetalPrepackOpContext.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/metal/MetalPrepackOpContext.h lines 18–121
class Conv2dOpContext : public torch::jit::CustomClassHolder {
public:
SerializationTypeConv2dPrePack pack() {
return std::make_tuple(
weight_,
bias_,
stride_,
padding_,
dilation_,
groups_,
output_min_,
output_max_);
}
Conv2dOpContext() = delete;
Conv2dOpContext(
at::Tensor&& weight,
std::optional<at::Tensor>&& bias,
std::vector<int64_t> stride,
std::vector<int64_t> padding,
std::vector<int64_t> dilation,
int64_t groups,
std::optional<Scalar> output_min,
std::optional<Scalar> output_max)
: weight_(std::move(weight)),
bias_(std::move(bias)),
stride_(std::move(stride)),
padding_(std::move(padding)),
dilation_(std::move(dilation)),
groups_(groups),
output_min_(std::move(output_min)),
output_max_(std::move(output_max)) {}
~Conv2dOpContext() override {
if (releaseCallback_) {
releaseCallback_(conv2dOp_);
}
}
void release_resources() override {
if (releaseCallback_) {
releaseCallback_(conv2dOp_);
}
}
const Tensor& get_weight() const {
return weight_;
}
const std::optional<Tensor>& get_bias() const {
return bias_;
}
const std::vector<int64_t>& get_stride() const {
return stride_;
}
const std::vector<int64_t>& get_padding() const {
return padding_;
}
const std::vector<int64_t>& get_dilation() const {
return dilation_;
}
int64_t get_groups() const {
return groups_;
}
const std::optional<Scalar>& get_output_min() const {
return output_min_;
}
const std::optional<Scalar>& get_output_max() const {
return output_max_;
}
void set_conv2dOpPtr(void* ptr) {
conv2dOp_ = ptr;
}
void* get_conv2dOpPtr() const {
return conv2dOp_;
}
void set_releaseCallback(const std::function<void(void*)>& func) {
releaseCallback_ = func;
}
std::function<void(void*)>& get_releaseCallback() {
return releaseCallback_;
}
private:
Tensor weight_;
std::optional<Tensor> bias_;
std::vector<int64_t> stride_;
std::vector<int64_t> padding_;
std::vector<int64_t> dilation_;
int64_t groups_;
std::optional<Scalar> output_min_;
std::optional<Scalar> output_max_;
std::function<void(void*)> releaseCallback_ = nullptr;
void* conv2dOp_ = nullptr; // reserved to hold MPSCNNConv2dOp objects
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free