spatial_dilated_max_pooling3d Class — pytorch Architecture
Architecture documentation for the spatial_dilated_max_pooling3d class in Pooling.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/quantized/cpu/Pooling.cpp lines 96–177
// Dilated 3-D max pooling over a quantized input buffer.
//
// Both buffers are indexed as dense [batch][channel][time][height][width]
// arrays: `qxd` with extents (nbatch, iC, iT, iH, iW) and `qyd` with extents
// (nbatch, iC, oT, oH, oW). For every output element the pooling window is
// walked in dilation-sized steps; taps that would fall before index 0 or at or
// beyond the input extent are skipped, and the largest `val_` encountered is
// stored in the output. A window without any in-range tap produces the lowest
// representable value of `T::underlying`.
//
// The (batch, channel) planes are handed to at::parallel_for, so each plane is
// processed independently of the others.
template <typename T>
void spatial_dilated_max_pooling3d(
    const T* qxd, // input data
    int64_t nbatch,
    int64_t iC, // input/output channels
    int64_t iT,
    int64_t iH,
    int64_t iW, // input sizes
    int64_t oT,
    int64_t oH,
    int64_t oW, // output sizes
    int64_t kT,
    int64_t kH,
    int64_t kW, // kernel size
    int64_t sT,
    int64_t sH,
    int64_t sW, // strides
    int64_t pT,
    int64_t pH,
    int64_t pW, // padding
    int64_t dT,
    int64_t dH,
    int64_t dW, // dilation
    T* qyd) { // output data (values only; no indices are written)
  // TODO: Further optimize the performance suggested by @mingfeima. Parallel on NCTH and cache the output indices from W.
  using underlying_t = typename T::underlying;

  // Element counts of one input/output slice and one input/output plane.
  const int64_t in_slice = iH * iW;
  const int64_t in_plane = iT * in_slice;
  const int64_t out_slice = oH * oW;
  const int64_t out_plane = oT * out_slice;

  // The output has as many channels as the input, so a flat plane index
  // (batch * iC + channel) addresses both buffers.
  const int64_t plane_count = nbatch * iC;

  at::parallel_for(0, plane_count, 0, [&](int64_t start, int64_t end) {
    for (const auto plane : c10::irange(start, end)) {
      const T* src = qxd + plane * in_plane;
      T* dst = qyd + plane * out_plane;

      for (int64_t ot = 0; ot < oT; ++ot) {
        // Window bounds along time: clamp the end to the input extent, then
        // step the start forward by the dilation until it is inside the input.
        int64_t t_lo = ot * sT - pT;
        const int64_t t_hi = std::min(t_lo + (kT - 1) * dT + 1, iT);
        while (t_lo < 0) {
          t_lo += dT;
        }

        for (int64_t oh = 0; oh < oH; ++oh) {
          // Same treatment along height.
          int64_t h_lo = oh * sH - pH;
          const int64_t h_hi = std::min(h_lo + (kH - 1) * dH + 1, iH);
          while (h_lo < 0) {
            h_lo += dH;
          }

          for (int64_t ow = 0; ow < oW; ++ow) {
            // ... and along width.
            int64_t w_lo = ow * sW - pW;
            const int64_t w_hi = std::min(w_lo + (kW - 1) * dW + 1, iW);
            while (w_lo < 0) {
              w_lo += dW;
            }

            // Running maximum over every in-range tap of this window.
            underlying_t best = std::numeric_limits<underlying_t>::lowest();
            for (int64_t t = t_lo; t < t_hi; t += dT) {
              for (int64_t y = h_lo; y < h_hi; y += dH) {
                const T* line = src + t * in_slice + y * iW;
                for (int64_t x = w_lo; x < w_hi; x += dW) {
                  const auto candidate = line[x].val_;
                  if (candidate > best) {
                    best = candidate;
                  }
                }
              }
            }

            dst[ot * out_slice + oh * oW + ow] = T(best);
          }
        }
      }
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free