Home / Class / fn Class — pytorch Architecture

fn Class — pytorch Architecture

Architecture documentation for the `_map` class template specialization (indexed here as "fn", the name of its callable parameter) in NestedTensorUtils.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/nested/NestedTensorUtils.h lines 278–334

/// Partial specialization of `_map` for a callable of type `F` that returns
/// `A` and whose parameter types are carried by the typelist `Args...`.
///
/// Applies the callable across a pack of NestedNodes (one node per parameter)
/// and gathers the results into a NestedNode<A>.
template <class F, class A, class... Args>
class _map<F, A, c10::guts::typelist::typelist<Args...>> {
 public:
  /// Invokes `fn` once on a single set of already-unwrapped arguments and
  /// returns its result.
  static A function_one(const F& fn, const Args&... nested_node) {
    return fn(nested_node...);
  }

  /// Maps `fn` over the given NestedNodes.
  ///
  /// - If every node is a leaf, `fn` is called once on the payloads and the
  ///   result is wrapped in a leaf NestedNode<A>.
  /// - Otherwise `fn` is called `degree` times, where `degree` is the largest
  ///   degree among the arguments. For call `i`, a leaf contributes its
  ///   payload, a non-leaf of degree 1 contributes its only child, and any
  ///   other node contributes child `i`.
  ///
  /// TORCH_CHECK fails if two arguments both have degree > 1 but differ in
  /// degree, or if a non-leaf argument of degree 0 is indexed.
  static NestedNode<A> function(
      const F& fn,
      const NestedNode<Args>&... nested_node) {
    size_t degree = 0;
    bool all_leaf = true;
    // One pass over the arguments: record whether all are leaves, track the
    // maximum degree, and validate that degrees > 1 agree. Nodes are taken by
    // const reference so inspecting them does not copy them.
    c10::guts::tuple_map(
        std::forward_as_tuple(nested_node...),
        [&all_leaf, &degree](const auto& n) {
          all_leaf = all_leaf && (n.is_leaf());
          if (degree > 1 && n.degree() > 1) {
            TORCH_CHECK(
                degree == n.degree(), "NestedNodes must match in degree.");
          }
          if (n.degree() > degree) {
            degree = n.degree();
          }
          // tuple_map yields the mapped values (see its use below), so the
          // mapper has to return something; the value itself is unused.
          return nullptr;
        });
    // All NestedNodes just wrap regular objects.
    if (all_leaf) {
      // `fn` is a `const F&`, not a forwarding reference, so it is invoked
      // directly (exactly as function_one does) instead of going through
      // std::forward<F>, which is ill-formed for a non-reference, non-const F.
      return NestedNode<A>(fn(nested_node.payload()...));
    }
    // Some NestedNodes wrap regular Tensors, some NestedTensors and some other
    // types.
    std::vector<A> result;
    // Exactly `degree` results are produced below.
    result.reserve(degree);
    for (size_t i = 0; i < degree; i++) {
      // Select, for every argument, the value that takes part in call `i`.
      auto children = c10::guts::tuple_map(
          std::forward_as_tuple(nested_node...), [&i](const auto& a) {
            // decltype(a) is a const reference here; strip it before checking
            // that the argument really is a NestedNode.
            static_assert(
                c10::guts::is_instantiation_of<
                    NestedNode,
                    std::decay_t<decltype(a)>>::value,
                "Internal error.");
            // Broadcast regular arguments across NestedTensor constituents.
            // This could be a Tensor, integer or anything else really.
            if (a.is_leaf()) {
              return a.payload();
            }
            // Broadcast NestedTensors with one constituent. `a` is known to
            // be a non-leaf at this point.
            if (a.degree() == 1) {
              return a.children(0);
            }
            TORCH_CHECK(a.degree() > 0, "Internal assert.");
            return a.children(i);
          });
      // Unpack the selected values into a single call of `fn` and store the
      // result at position `i`.
      std::apply(
          [&result, &fn](Args... filtered) {
            result.emplace_back(function_one(fn, filtered...));
          },
          std::move(children));
    }
    return NestedNode<A>(std::move(result));
  }
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free