Home / Class / Conv2dOpContext Class — pytorch Architecture

Conv2dOpContext Class — pytorch Architecture

Architecture documentation for the Conv2dOpContext class in MetalPrepackOpContext.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/metal/MetalPrepackOpContext.h lines 18–121

class Conv2dOpContext : public torch::jit::CustomClassHolder {
 public:
  SerializationTypeConv2dPrePack pack() {
    return std::make_tuple(
        weight_,
        bias_,
        stride_,
        padding_,
        dilation_,
        groups_,
        output_min_,
        output_max_);
  }
  Conv2dOpContext() = delete;
  Conv2dOpContext(
      at::Tensor&& weight,
      std::optional<at::Tensor>&& bias,
      std::vector<int64_t> stride,
      std::vector<int64_t> padding,
      std::vector<int64_t> dilation,
      int64_t groups,
      std::optional<Scalar> output_min,
      std::optional<Scalar> output_max)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        output_min_(std::move(output_min)),
        output_max_(std::move(output_max)) {}

  ~Conv2dOpContext() override {
    if (releaseCallback_) {
      releaseCallback_(conv2dOp_);
    }
  }

  void release_resources() override {
    if (releaseCallback_) {
      releaseCallback_(conv2dOp_);
    }
  }

  const Tensor& get_weight() const {
    return weight_;
  }

  const std::optional<Tensor>& get_bias() const {
    return bias_;
  }

  const std::vector<int64_t>& get_stride() const {
    return stride_;
  }

  const std::vector<int64_t>& get_padding() const {
    return padding_;
  }

  const std::vector<int64_t>& get_dilation() const {
    return dilation_;
  }

  int64_t get_groups() const {
    return groups_;
  }

  const std::optional<Scalar>& get_output_min() const {
    return output_min_;
  }

  const std::optional<Scalar>& get_output_max() const {
    return output_max_;
  }

  void set_conv2dOpPtr(void* ptr) {
      conv2dOp_ = ptr;
  }

  void* get_conv2dOpPtr() const {
    return conv2dOp_;
  }

  void set_releaseCallback(const std::function<void(void*)>& func) {
    releaseCallback_ = func;
  }

  std::function<void(void*)>& get_releaseCallback() {
     return releaseCallback_;
  }

  private:
    Tensor weight_;
    std::optional<Tensor> bias_;
    std::vector<int64_t> stride_;
    std::vector<int64_t> padding_;
    std::vector<int64_t> dilation_;
    int64_t groups_;
    std::optional<Scalar> output_min_;
    std::optional<Scalar> output_max_;
    std::function<void(void*)> releaseCallback_ = nullptr;
    void* conv2dOp_ = nullptr; // reserved to hold MPSCNNConv2dOp objects
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free