_map Class — pytorch Architecture
Architecture documentation for the `_map` class template (a partial specialization whose callable parameter is named `fn`) in NestedTensorUtils.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/nested/NestedTensorUtils.h lines 278–334
/// Partial specialization of `_map` selected when the third template argument
/// is a typelist of argument types `Args...` (presumably the parameter types
/// of `F` — confirm against the caller that instantiates `_map`).
///
/// Template parameters:
///   F    - callable type; it is only ever invoked through a `const F&`.
///   A    - result type of one call to `fn`; collected into NestedNode<A>.
///   Args - one type per NestedNode argument passed to `function`.
template <class F, class A, class... Args>
class _map<F, A, c10::guts::typelist::typelist<Args...>> {
 public:
  /// Applies `fn` to one set of already-unwrapped arguments.
  static A function_one(const F& fn, const Args&... nested_node) {
    return fn(nested_node...);
  }

  /// Applies `fn` element-wise across the NestedNode arguments.
  ///
  /// - If every argument is a leaf, `fn` is called once on the payloads and
  ///   the result is wrapped in a NestedNode<A>.
  /// - Otherwise `fn` is called once per index i in [0, degree), where degree
  ///   is the largest degree among the arguments:
  ///     * leaf arguments are broadcast (their payload is passed for every i),
  ///     * non-leaf arguments with exactly one child are broadcast (child 0 is
  ///       passed for every i),
  ///     * all other arguments supply their child i.
  ///   The per-index results become the children of the returned NestedNode.
  ///
  /// Fails a TORCH_CHECK if two arguments both have degree > 1 but differ in
  /// degree, or if a non-leaf argument with no children must supply child i.
  static NestedNode<A> function(
      const F& fn,
      const NestedNode<Args>&... nested_node) {
    size_t degree = 0;
    bool all_leaf = true;
    // Inspect each node through a const reference; taking it by value would
    // copy the node (and its children vector) just to read two properties.
    c10::guts::tuple_map(
        std::forward_as_tuple(nested_node...),
        [&all_leaf, &degree](const auto& n) {
          all_leaf = all_leaf && n.is_leaf();
          // Once a degree > 1 has been seen, every other degree > 1 must
          // match it; degree-1 nodes are allowed and broadcast below.
          if (degree > 1 && n.degree() > 1) {
            TORCH_CHECK(
                degree == n.degree(), "NestedNodes must match in degree.");
          }
          if (n.degree() > degree) {
            degree = n.degree();
          }
          return nullptr;
        });

    // All NestedNodes just wrap regular objects.
    if (all_leaf) {
      // `fn` is a `const F&`, not a forwarding reference, so it is called
      // directly (std::forward<F> on it cannot bind for non-reference F).
      return NestedNode<A>(fn(nested_node.payload()...));
    }

    // Some NestedNodes wrap regular Tensors, some NestedTensors and some other
    // types.
    std::vector<A> result;
    result.reserve(degree);
    for (size_t i = 0; i < degree; ++i) {
      // Select the i-th argument set. Nodes are read through const references
      // so no NestedNode is copied on each iteration.
      auto children = c10::guts::tuple_map(
          std::forward_as_tuple(nested_node...), [&i](const auto& a) {
            static_assert(
                c10::guts::is_instantiation_of<
                    NestedNode,
                    std::decay_t<decltype(a)>>::value,
                "Internal error.");
            // Broadcast regular arguments across NestedTensor constituents.
            // This could be a Tensor, integer or anything else really.
            if (a.is_leaf()) {
              return a.payload();
            }
            // Broadcast NestedTensors with one constituent (leaves were
            // already handled above).
            if (a.degree() == 1) {
              return a.children(0);
            }
            TORCH_CHECK(a.degree() > 0, "Internal assert.");
            return a.children(i);
          });
      std::apply(
          [&result, &fn](Args... filtered) {
            result.emplace_back(function_one(fn, filtered...));
          },
          std::move(children));
    }
    return NestedNode<A>(std::move(result));
  }
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free