AssertVectorized Class — PyTorch Architecture
Architecture documentation for the `AssertVectorized` class in `vec_test_all_types.h` from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/test/vec_test_all_types.h lines 817–935
/// Test helper that compares an expected vectorized value against an actual
/// one using GoogleTest assertions. On mismatch, it attaches a detailed message
/// containing caller-supplied info, the optional test seed, the recorded input
/// arguments (zero to three), both values, and the first mismatching index.
///
/// T is the vectorized type under test. It must support `store()` into a
/// buffer of `UvalueType<T>` and be streamable with `operator<<`.
template <typename T>
class AssertVectorized
{
public:
    /// One-input form: records the seed and a single argument.
    AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0)
        : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), argSize(1)
    {
    }
    /// Two-input form: records the seed and two arguments.
    AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0, const T& input1)
        : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), arg1(input1), argSize(2)
    {
    }
    /// Three-input form: records the seed and three arguments.
    AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual, const T& input0, const T& input1, const T& input2)
        : additionalInfo(info), testSeed(seed), exp(expected), act(actual), arg0(input0), arg1(input1), arg2(input2), argSize(3)
    {
    }
    /// Seed-only form: no input arguments are printed on failure.
    AssertVectorized(const std::string& info, TestSeed seed, const T& expected, const T& actual) : additionalInfo(info), testSeed(seed), exp(expected), act(actual)
    {
    }
    /// Seedless form: neither a seed nor input arguments are printed on failure.
    AssertVectorized(const std::string& info, const T& expected, const T& actual) : additionalInfo(info), exp(expected), act(actual), hasSeed(false)
    {
    }

    /// Builds the human-readable failure report.
    /// @param index  index reported as the first mismatch (as computed by `check`).
    /// @return multi-line description of the failing comparison.
    std::string getDetail(int index) const
    {
        using UVT = UvalueType<T>;
        std::stringstream stream;
        // For floating-point scalars, max_digits10 prints enough digits that
        // values differing in the last bit are distinguishable in the report.
        stream.precision(std::numeric_limits<UVT>::max_digits10);
        stream << "Failure Details:\n";
        stream << additionalInfo << '\n';
        if (hasSeed)
        {
            stream << "Test Seed to reproduce: " << testSeed << '\n';
        }
        if (argSize > 0)
        {
            stream << "Arguments:\n";
            stream << "#\t " << arg0 << '\n';
            // Use '>' (not '==') so a three-argument check also prints arg1.
            if (argSize > 1)
            {
                stream << "#\t " << arg1 << '\n';
            }
            if (argSize > 2)
            {
                stream << "#\t " << arg2 << '\n';
            }
        }
        stream << "Expected:\n#\t" << exp << "\nActual:\n#\t" << act;
        stream << "\nFirst mismatch Index: " << index;
        return stream.str();
    }

    /// Compares `exp` and `act` slot by slot after storing both into scalar arrays.
    ///
    /// @param bitwise             compare raw bit patterns (takes precedence over tolerance).
    /// @param checkWithTolerance  compare with `nearlyEqual` using an absolute error
    ///                            derived from `toleranceEps` via `correctEpsilon`.
    /// @param toleranceEps        tolerance input forwarded to `correctEpsilon`.
    /// @return true as soon as `::testing::Test::HasFailure()` reports a failure,
    ///         false if the loop completes. Note: HasFailure() reflects the whole
    ///         current test, so a failure recorded before this call also makes it
    ///         return true on the first slot.
    bool check(bool bitwise = false, bool checkWithTolerance = false, ValueType<T> toleranceEps = {}) const
    {
        using UVT = UvalueType<T>;
        using BVT = BitType<UVT>;
        UVT absErr = correctEpsilon(toleranceEps);
        constexpr int sizeX = VecTypeHelper<T>::holdCount * VecTypeHelper<T>::unitStorageCount;
        constexpr int unitStorageCount = VecTypeHelper<T>::unitStorageCount;
        CACHE_ALIGN UVT expArr[sizeX];
        CACHE_ALIGN UVT actArr[sizeX];
        exp.store(expArr);
        act.store(actArr);
        // Reported index is i / unitStorageCount. Presumably this maps a scalar
        // storage slot back to its logical vector element (e.g. complex types
        // store several scalars per element) — confirm against VecTypeHelper.
        if (bitwise)
        {
            for (const auto i : c10::irange(sizeX)) {
                BVT b_exp = c10::bit_cast<BVT>(expArr[i]);
                BVT b_act = c10::bit_cast<BVT>(actArr[i]);
                EXPECT_EQ(b_exp, b_act) << getDetail(i / unitStorageCount);
                if (::testing::Test::HasFailure())
                    return true;
            }
        }
        else if (checkWithTolerance)
        {
            for (const auto i : c10::irange(sizeX)) {
                EXPECT_EQ(nearlyEqual<UVT>(expArr[i], actArr[i], absErr), true) << expArr[i] << "!=" << actArr[i] << '\n' << getDetail(i / unitStorageCount);
                if (::testing::Test::HasFailure())
                    return true;
            }
        }
        else
        {
            for (const auto i : c10::irange(sizeX)) {
                if constexpr (std::is_same_v<UVT, float>)
                {
                    // Skip the ULP-based comparison when check_both_nan holds
                    // (presumably: both values are NaN, treated as a match).
                    if (!check_both_nan(expArr[i], actArr[i])) {
                        EXPECT_FLOAT_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount);
                    }
                }
                else if constexpr (std::is_same_v<UVT, double>)
                {
                    if (!check_both_nan(expArr[i], actArr[i]))
                    {
                        EXPECT_DOUBLE_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount);
                    }
                }
                else
                {
                    EXPECT_EQ(expArr[i], actArr[i]) << getDetail(i / unitStorageCount);
                }
                if (::testing::Test::HasFailure())
                    return true;
            }
        }
        return false;
    }
private:
    std::string additionalInfo;
    TestSeed testSeed;
    T exp;
    T act;
    T arg0;
    T arg1;
    T arg2;
    int argSize = 0;      // number of recorded input arguments (0..3)
    bool hasSeed = true;  // false only for the seedless constructor
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free