ArgTypeTestKernel Class — PyTorch Architecture
Architecture documentation for the ArgTypeTestKernel class in op_registration_test.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/core/op_registration/op_registration_test.cpp lines 772–815
/// Test kernel that checks an argument of type `InputType` and returns a
/// canned value of type `OutputType`.
///
/// When invoked, the kernel passes its incoming argument to
/// `inputExpectation` and returns the stored `output`. The static `test`
/// overloads register an operator named "_test::my_op" + `schema` through the
/// modern functor-kernel API, the legacy lambda API, or both. They then call
/// that operator with `input` and hand the returned result to
/// `outputExpectation`.
template<class InputType, class OutputType = InputType>
struct ArgTypeTestKernel final : OperatorKernel {
  /// @param input             stored in `input_`; not read by operator().
  /// @param inputExpectation  invoked with the argument the kernel receives.
  /// @param output            value returned from every kernel invocation.
  explicit ArgTypeTestKernel(InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output)
  : input_(std::move(input)), inputExpectation_(std::move(inputExpectation)), output_(std::move(output)) {}

  /// Kernel body: check the received argument, then return the canned output.
  OutputType operator()(InputType input) const {
    // inputExpectation_ takes a const reference, so std::move would be a no-op.
    inputExpectation_(input);
    return output_;
  }

  /// Runs the check with the modern API first, then with the legacy API.
  static void test(TestModernAndLegacyAPI, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
    test(TestModernAPI(), input, inputExpectation, output, outputExpectation, schema);
    test(TestLegacyAPI(), input, inputExpectation, output, outputExpectation, schema);
  }

  /// Registers this functor as a catch-all kernel and runs the check.
  static void test(TestModernAPI, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
    test_([&] {
      return c10::RegisterOperators().op("_test::my_op" + schema, c10::RegisterOperators::options().catchAllKernel<ArgTypeTestKernel>(input, inputExpectation, output));
    }, input, inputExpectation, output, outputExpectation, schema);
  }

  /// Registers an equivalent plain lambda as the kernel and runs the check.
  static void test(TestLegacyAPI, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
    test_([&] {
      // The kernel lambda captures by copy: test_ invokes the operator after
      // this registration callback has already returned.
      return c10::RegisterOperators().op("_test::my_op" + schema, [=] (InputType kernelInput) -> OutputType {
        inputExpectation(kernelInput);
        return output;
      });
    }, input, inputExpectation, output, outputExpectation, schema);
  }

private:
  /// Performs the registration and looks up "_test::my_op" with an empty
  /// overload name. It then calls the operator with `input` and passes the
  /// result to `outputExpectation`.
  ///
  /// `inputExpectation`, `output` and `schema` are not read here; they are
  /// already captured inside `registration`.
  static void test_(std::function<c10::RegisterOperators()> registration, InputType input, std::function<void(const InputType&)> inputExpectation, OutputType output, std::function<void(const c10::Stack&)> outputExpectation, const std::string& schema) {
    // Hold the returned RegisterOperators for the rest of this scope. Presumably
    // the registration is removed when it is destroyed — TODO confirm.
    auto registry = registration();
    auto op = Dispatcher::singleton().findSchema({"_test::my_op", ""});
    ASSERT_TRUE(op.has_value()); // assert schema is registered
    auto actualOutput = callOp(*op, input);
    outputExpectation(actualOutput);
  }

  // Stored by the constructor but not read by operator().
  InputType input_;
  std::function<void(const InputType&)> inputExpectation_;
  OutputType output_;
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free