Home / Class / test_binary Class — pytorch Architecture

test_binary Class — pytorch Architecture

Architecture documentation for the test_binary function template in vec_test_all_types.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/test/vec_test_all_types.h lines 995–1094

// Checks a binary vectorized operation against an element-wise reference.
//
// For every domain returned by testCase.getDomains(), runs a number of
// randomized trials: two input arrays of vec_type::size() elements are
// generated, `expectedFunction` is applied element by element to build the
// reference result, `actualFunction` is applied to the loaded vectors, and
// the two results are compared through AssertVectorized. Afterwards every
// entry of testCase.getCustomChecks() is evaluated with its fixed arguments.
//
// When Op2 is also invocable as (scalar, vector) and (vector, scalar), the
// mixed scalar/vector overloads are verified in each trial as well, using the
// first generated element of each input as the scalar operand.
//
// The function returns as soon as any check() call returns true.
// NOTE(review): check() presumably returns true when an assertion failure was
// recorded -- confirm against AssertVectorized.
//
// Template parameters:
//   T       vectorized type under test (must provide size() and loadu()).
//   Op1     callable computing the expected result for a pair of scalars.
//   Op2     callable under test, invoked with vector (and optionally mixed
//           scalar/vector) operands.
//   Filter  optional callable handed to call_filter() together with each
//           generated operand pair before the expected values are computed.
//
// Parameters:
//   testNameInfo      label forwarded to AssertVectorized.
//   expectedFunction  scalar reference implementation.
//   actualFunction    implementation under test.
//   testCase          supplies domains, trial count, seed, custom checks and
//                     the bitwise-comparison flag.
//   filter            see Filter above; defaults to a value-initialized Filter.
template< typename T, typename Op1, typename Op2, typename Filter = std::nullptr_t>
void test_binary(
    std::string testNameInfo,
    Op1 expectedFunction,
    Op2 actualFunction, const TestingCase<T>& testCase, Filter filter = {}) {
    using vec_type = T;
    using VT = ValueType<T>;
    using UVT = UvalueType<T>;
    constexpr int el_count = vec_type::size();
    CACHE_ALIGN VT vals0[el_count];
    CACHE_ALIGN VT vals1[el_count];
    CACHE_ALIGN VT expected[el_count];
    // The scalar-related buffers are only touched when kCanUseScalar is true.
    [[maybe_unused]] CACHE_ALIGN VT expectedWithLeftScalar[el_count];
    [[maybe_unused]] CACHE_ALIGN VT expectedWithRightScalar[el_count];
    [[maybe_unused]] VT scalar0;
    [[maybe_unused]] VT scalar1;
    bool bitwise = testCase.isBitwise();
    // Fallback range used for any argument the domain does not constrain.
    UVT default_start = std::is_floating_point_v<UVT> ? std::numeric_limits<UVT>::lowest() : std::numeric_limits<UVT>::min();
    UVT default_end = std::numeric_limits<UVT>::max();
    // Bind once by const reference: the container is only read, so copying it
    // (and querying the test case a second time for the loop) is unnecessary.
    const auto& domains = testCase.getDomains();
    auto domains_size = domains.size();
    auto test_trials = testCase.getTrialCount();
    int trialCount = getTrialCount<UVT>(test_trials, domains_size);
    TestSeed seed = testCase.getTestSeed();
    uint64_t changeSeedBy = 0;
    // True when Op2 also accepts (scalar, vector) and (vector, scalar).
    constexpr bool kCanUseScalar = std::is_invocable_v<Op2, VT, T> && std::is_invocable_v<Op2, T, VT>;
    for (const CheckWithinDomains<UVT>& dmn : domains) {
        size_t dmn_argc = dmn.ArgsDomain.size();
        UVT start0 = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start;
        UVT end0 = dmn_argc > 0 ? dmn.ArgsDomain[0].end : default_end;
        UVT start1 = dmn_argc > 1 ? dmn.ArgsDomain[1].start : default_start;
        UVT end1 = dmn_argc > 1 ? dmn.ArgsDomain[1].end : default_end;
        // NOTE(review): changeSeedBy advances by 1 per domain while the second
        // generator uses changeSeedBy + 1, so generator0 of the next domain is
        // built from the same seed offset as generator1 of this one. Left
        // unchanged to keep generated inputs reproducible -- confirm intent.
        ValueGen<VT> generator0(start0, end0, seed.add(changeSeedBy));
        ValueGen<VT> generator1(start1, end1, seed.add(changeSeedBy + 1));
        for ([[maybe_unused]] const auto trial : c10::irange(trialCount)) {
          for (const auto k : c10::irange(el_count)) {
            vals0[k] = generator0.get();
            vals1[k] = generator1.get();
            if (k == 0) {
              // The scalars are captured before the filter runs on element 0.
              scalar0 = vals0[0];
              scalar1 = vals1[0];
            }
            call_filter(filter, vals0[k], vals1[k]);
            if constexpr (kCanUseScalar) {
              // Also pass the scalar/element pairings through the filter.
              call_filter(filter, vals0[k], scalar1);
              call_filter(filter, scalar0, vals1[k]);
            }
          }
          for (const auto k : c10::irange(el_count)) {
            // map operator
            expected[k] = expectedFunction(vals0[k], vals1[k]);
            if constexpr (kCanUseScalar) {
              expectedWithLeftScalar[k] = expectedFunction(scalar0, vals1[k]);
              expectedWithRightScalar[k] = expectedFunction(vals0[k], scalar1);
            }
          }
          // test
          auto input0 = vec_type::loadu(vals0);
          auto input1 = vec_type::loadu(vals1);
          auto actual = actualFunction(input0, input1);
          auto vec_expected = vec_type::loadu(expected);
          AssertVectorized<vec_type> vecAssert(
              testNameInfo, seed, vec_expected, actual, input0, input1);
          if (vecAssert.check(
                  bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) {
            return;
          }
          if constexpr (kCanUseScalar) {
            auto actualWithLeftScalar = actualFunction(scalar0, input1);
            auto actualWithRightScalar = actualFunction(input0, scalar1);
            auto vec_expectedWithLeftScalar = vec_type::loadu(expectedWithLeftScalar);
            auto vec_expectedWithRightScalar = vec_type::loadu(expectedWithRightScalar);
            AssertVectorized<vec_type> vecAssertWithLeftScalar(
                testNameInfo, seed, vec_expectedWithLeftScalar, actualWithLeftScalar, scalar0, input1);
            if (vecAssertWithLeftScalar.check(
                    bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) {
              return;
            }
            AssertVectorized<vec_type> vecAssertWithRightScalar(
                testNameInfo, seed, vec_expectedWithRightScalar, actualWithRightScalar, input0, scalar1);
            if (vecAssertWithRightScalar.check(
                    bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) {
              return;
            }
          }
        } // trial
        changeSeedBy += 1;
    }
    for (const auto& custom : testCase.getCustomChecks()) {
        // Read-only access; no need to copy the argument list per check.
        const auto& args = custom.Args;
        if (args.size() > 0) {
            auto input0 = vec_type{ args[0] };
            // With a single argument, the same value is used for both operands.
            auto input1 = args.size() > 1 ? vec_type{ args[1] } : vec_type{ args[0] };
            auto actual = actualFunction(input0, input1);
            auto vec_expected = vec_type(custom.expectedResult);
            AssertVectorized<vec_type> vecAssert(testNameInfo, seed, vec_expected, actual, input0, input1);
            if (vecAssert.check()) return;
        }
    }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free