test_ternary Function Template — pytorch Architecture
Architecture documentation for the test_ternary function template in vec_test_all_types.h from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/test/vec_test_all_types.h lines 1158–1215
template< typename T, typename Op1, typename Op2, typename Filter = std::nullptr_t>
void test_ternary(
std::string testNameInfo,
Op1 expectedFunction,
Op2 actualFunction, const TestingCase<T>& testCase, Filter filter = {}) {
using vec_type = T;
using VT = ValueType<T>;
using UVT = UvalueType<T>;
constexpr int el_count = vec_type::size();
CACHE_ALIGN VT vals0[el_count];
CACHE_ALIGN VT vals1[el_count];
CACHE_ALIGN VT vals2[el_count];
CACHE_ALIGN VT expected[el_count];
bool bitwise = testCase.isBitwise();
UVT default_start = std::is_floating_point_v<UVT> ? std::numeric_limits<UVT>::lowest() : std::numeric_limits<UVT>::min();
UVT default_end = std::numeric_limits<UVT>::max();
auto domains = testCase.getDomains();
auto domains_size = domains.size();
auto test_trials = testCase.getTrialCount();
int trialCount = getTrialCount<UVT>(test_trials, domains_size);
TestSeed seed = testCase.getTestSeed();
uint64_t changeSeedBy = 0;
for (const CheckWithinDomains<UVT>& dmn : testCase.getDomains()) {
size_t dmn_argc = dmn.ArgsDomain.size();
UVT start0 = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start;
UVT end0 = dmn_argc > 0 ? dmn.ArgsDomain[0].end : default_end;
UVT start1 = dmn_argc > 1 ? dmn.ArgsDomain[1].start : default_start;
UVT end1 = dmn_argc > 1 ? dmn.ArgsDomain[1].end : default_end;
UVT start2 = dmn_argc > 2 ? dmn.ArgsDomain[2].start : default_start;
UVT end2 = dmn_argc > 2 ? dmn.ArgsDomain[2].end : default_end;
ValueGen<VT> generator0(start0, end0, seed.add(changeSeedBy));
ValueGen<VT> generator1(start1, end1, seed.add(changeSeedBy + 1));
ValueGen<VT> generator2(start2, end2, seed.add(changeSeedBy + 2));
for ([[maybe_unused]] const auto trial : c10::irange(trialCount)) {
for (const auto k : c10::irange(el_count)) {
vals0[k] = generator0.get();
vals1[k] = generator1.get();
vals2[k] = generator2.get();
call_filter(filter, vals0[k], vals1[k], vals2[k]);
// map operator
expected[k] = expectedFunction(vals0[k], vals1[k], vals2[k]);
}
// test
auto input0 = vec_type::loadu(vals0);
auto input1 = vec_type::loadu(vals1);
auto input2 = vec_type::loadu(vals2);
auto actual = actualFunction(input0, input1, input2);
auto vec_expected = vec_type::loadu(expected);
AssertVectorized<vec_type> vecAssert(
testNameInfo, seed, vec_expected, actual, input0, input1, input2);
if (vecAssert.check(
bitwise, dmn.CheckWithTolerance, dmn.ToleranceError))
return;
} // trial
changeSeedBy += 1;
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free