Home / Class / AssertVectorized Class — pytorch Architecture

AssertVectorized Class — pytorch Architecture

Architecture documentation for the AssertVectorized class in vec_test_all_types.h from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/test/vec_test_all_types.h lines 817–935

// Compares an expected and an actual value of vector type T element by
// element and reports mismatches through GoogleTest expectations.
//
// Besides the two values under comparison, an instance can carry a free-form
// description, a test seed and up to three input operands. Those extras are
// used only to build the failure message produced by getDetail().
template <typename T>
class AssertVectorized
{
public:
    // One input operand; the seed is included in the failure message.
    AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0)
        : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), argSize(1)
    {
    }
    // Two input operands; the seed is included in the failure message.
    AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0, const T& input1)
        : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), arg1(input1), argSize(2)
    {
    }
    // Three input operands; the seed is included in the failure message.
    AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0, const T& input1, const T& input2)
        : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), arg1(input1), arg2(input2), argSize(3)
    {
    }
    // No input operands (argSize stays 0); the seed is still reported.
    AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual) : additionalInfo(info), testSeed(seed), exp(expected), act(actual)
    {
    }
    // No input operands and no seed: hasSeed is cleared so getDetail() omits
    // the "Test Seed to reproduce" line.
    AssertVectorized(const std::string& info, const T& expected, const T& actual) : additionalInfo(info), exp(expected), act(actual), hasSeed(false)
    {
    }

    // Builds the human-readable failure message.
    //
    // The message contains, in order: the additional info string, the test
    // seed (only when one was supplied), every stored input operand, the
    // expected and actual values, and `index` labelled as the first mismatch
    // index.
    std::string getDetail(int index) const
    {
        using UVT = UvalueType<T>;
        std::stringstream stream;
        // Print enough digits to round-trip the underlying value type.
        stream.precision(std::numeric_limits<UVT>::max_digits10);
        stream << "Failure Details:\n";
        stream << additionalInfo << '\n';
        if (hasSeed)
        {
            stream << "Test Seed to reproduce: " << testSeed << '\n';
        }
        if (argSize > 0)
        {
            stream << "Arguments:\n";
            stream << "#\t " << arg0 << '\n';
            // Cumulative checks: with three operands both arg1 and arg2 must
            // be listed. Testing for equality here would skip arg1 whenever
            // argSize == 3.
            if (argSize >= 2)
            {
                stream << "#\t " << arg1 << '\n';
            }
            if (argSize >= 3)
            {
                stream << "#\t " << arg2 << '\n';
            }
        }
        stream << "Expected:\n#\t" << exp << "\nActual:\n#\t" << act;
        stream << "\nFirst mismatch Index: " << index;
        return stream.str();
    }

    // Compares `exp` and `act` element by element.
    //
    // bitwise            - compare the bit patterns of the elements.
    // checkWithTolerance - compare with nearlyEqual() using an absolute error
    //                      derived from toleranceEps via correctEpsilon().
    //                      Ignored when `bitwise` is true.
    // toleranceEps       - tolerance passed to correctEpsilon().
    //
    // When neither flag is set, float and double elements are compared with
    // EXPECT_FLOAT_EQ / EXPECT_DOUBLE_EQ unless check_both_nan() returns true
    // for the pair; all other element types use EXPECT_EQ.
    //
    // Returns true as soon as ::testing::Test::HasFailure() is true after an
    // element has been examined, and false when the loop completes.
    // NOTE(review): HasFailure() looks at the whole running test, so a failure
    // recorded earlier in the same test presumably also triggers the early
    // `true` return - confirm against callers.
    bool check(bool bitwise = false, bool checkWithTolerance = false, ValueType<T> toleranceEps = {}) const
    {
        using UVT = UvalueType<T>;
        using BVT = BitType<UVT>;
        UVT absErr = correctEpsilon(toleranceEps);
        constexpr int sizeX = VecTypeHelper<T>::holdCount * VecTypeHelper<T>::unitStorageCount;
        constexpr int unitStorageCount = VecTypeHelper<T>::unitStorageCount;
        CACHE_ALIGN UVT expArr[sizeX];
        CACHE_ALIGN UVT actArr[sizeX];
        exp.store(expArr);
        act.store(actArr);
        // In every branch the reported index is i / unitStorageCount, i.e. the
        // flat storage index scaled down by the per-element storage count.
        if (bitwise)
        {
            for (const auto i : c10::irange(sizeX)) {
                BVT b_exp = c10::bit_cast<BVT>(expArr[i]);
                BVT b_act = c10::bit_cast<BVT>(actArr[i]);
                EXPECT_EQ(b_exp, b_act) << getDetail(i / unitStorageCount);
                if (::testing::Test::HasFailure())
                    return true;
            }
        }
        else if (checkWithTolerance)
        {
            for (const auto i : c10::irange(sizeX)) {
                EXPECT_EQ(nearlyEqual<UVT>(expArr[i], actArr[i], absErr), true) << expArr[i] << "!=" << actArr[i] << '\n' << getDetail(i / unitStorageCount);
                if (::testing::Test::HasFailure())
                    return true;
            }
        }
        else
        {
            for (const auto i : c10::irange(sizeX)) {
                if constexpr (std::is_same_v<UVT, float>)
                {
                    // Skip the comparison when check_both_nan() accepts the pair.
                    if (!check_both_nan(expArr[i], actArr[i])) {
                        EXPECT_FLOAT_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount);
                    }
                }
                else if constexpr (std::is_same_v<UVT, double>)
                {
                    if (!check_both_nan(expArr[i], actArr[i]))
                    {
                        EXPECT_DOUBLE_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount);
                    }
                }
                else
                {
                    EXPECT_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount);
                }
                if (::testing::Test::HasFailure())
                    return true;
            }
        }
        return false;
    }

private:
    std::string additionalInfo; // free-form context printed in the failure message
    TestSeed testSeed;          // printed only when hasSeed is true
    T exp;                      // expected value
    T act;                      // actual value
    T arg0;                     // input operands; only the first argSize are meaningful
    T arg1;
    T arg2;
    int argSize = 0;            // number of input operands supplied (0..3)
    bool hasSeed = true;        // false only for the seed-less constructor
};

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free