mirror of https://github.com/Qortal/Brooklyn
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
252 lines
7.6 KiB
252 lines
7.6 KiB
// |
|
// Copyright © 2017 Arm Ltd. All rights reserved. |
|
// SPDX-License-Identifier: MIT |
|
// |
|
#include "InferenceTest.hpp"

#include <armnn/utility/Assert.hpp>
#include <armnnUtils/Filesystem.hpp>

#include "../src/armnn/Profiling.hpp"
#include <cxxopts/cxxopts.hpp>

#include <array>
#include <chrono>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

using namespace std;
using namespace std::chrono;
using namespace armnn::test;
|
|
|
namespace armnn |
|
{ |
|
namespace test |
|
{ |
|
/// Parse the command line of an ArmNN (or referencetests) inference test program. |
|
/// \return false if any error occurred during options processing, otherwise true |
|
bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider, |
|
InferenceTestOptions& outParams) |
|
{ |
|
cxxopts::Options options("InferenceTest", "Inference iteration parameters"); |
|
|
|
try |
|
{ |
|
// Adds generic options needed for all inference tests. |
|
options |
|
.allow_unrecognised_options() |
|
.add_options() |
|
("h,help", "Display help messages") |
|
("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.", |
|
cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0")) |
|
("inference-times-file", |
|
"If non-empty, each individual inference time will be recorded and output to this file", |
|
cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value("")) |
|
("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.", |
|
cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0")); |
|
|
|
std::vector<std::string> required; //to be passed as reference to derived inference tests |
|
|
|
// Adds options specific to the ITestCaseProvider. |
|
testCaseProvider.AddCommandLineOptions(options, required); |
|
|
|
auto result = options.parse(argc, argv); |
|
|
|
if (result.count("help")) |
|
{ |
|
std::cout << options.help() << std::endl; |
|
return false; |
|
} |
|
|
|
CheckRequiredOptions(result, required); |
|
|
|
} |
|
catch (const cxxopts::OptionException& e) |
|
{ |
|
std::cerr << e.what() << std::endl << options.help() << std::endl; |
|
return false; |
|
} |
|
catch (const std::exception& e) |
|
{ |
|
ARMNN_ASSERT_MSG(false, "Caught unexpected exception"); |
|
std::cerr << "Fatal internal error: " << e.what() << std::endl; |
|
return false; |
|
} |
|
|
|
if (!testCaseProvider.ProcessCommandLineOptions(outParams)) |
|
{ |
|
return false; |
|
} |
|
|
|
return true; |
|
} |
|
|
|
/// Validates a directory path, e.g. one supplied on the command line.
///
/// When \p dir is non-empty and does not already end in '/', a trailing '/' is appended in
/// place, so callers can concatenate file names directly. Every failure writes a diagnostic
/// to std::cerr.
///
/// \param dir  Directory path; may be modified in place as described above.
/// \return true if \p dir is non-empty, exists and refers to a directory; false otherwise.
bool ValidateDirectory(std::string& dir)
{
    // Guard: nothing to validate.
    if (dir.empty())
    {
        std::cerr << "No directory specified" << std::endl;
        return false;
    }

    // Normalise to a trailing separator before querying the filesystem.
    if (dir.back() != '/')
    {
        dir.push_back('/');
    }

    const bool pathExists = fs::exists(dir);
    if (!pathExists)
    {
        std::cerr << "Given directory " << dir << " does not exist" << std::endl;
        return false;
    }

    const bool isDirectory = fs::is_directory(dir);
    if (!isDirectory)
    {
        std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
    }
    return isDirectory;
}
|
|
|
bool InferenceTest(const InferenceTestOptions& params, |
|
const std::vector<unsigned int>& defaultTestCaseIds, |
|
IInferenceTestCaseProvider& testCaseProvider) |
|
{ |
|
#if !defined (NDEBUG) |
|
if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn. |
|
{ |
|
ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate."; |
|
} |
|
#endif |
|
|
|
double totalTime = 0; |
|
unsigned int nbProcessed = 0; |
|
bool success = true; |
|
|
|
// Opens the file to write inference times too, if needed. |
|
ofstream inferenceTimesFile; |
|
const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty(); |
|
if (recordInferenceTimes) |
|
{ |
|
inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out); |
|
if (!inferenceTimesFile.good()) |
|
{ |
|
ARMNN_LOG(error) << "Failed to open inference times file for writing: " |
|
<< params.m_InferenceTimesFile; |
|
return false; |
|
} |
|
} |
|
|
|
// Create a profiler and register it for the current thread. |
|
std::unique_ptr<IProfiler> profiler = std::make_unique<IProfiler>(); |
|
ProfilerManager::GetInstance().RegisterProfiler(profiler.get()); |
|
|
|
// Enable profiling if requested. |
|
profiler->EnableProfiling(params.m_EnableProfiling); |
|
|
|
// Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer |
|
std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0); |
|
if (warmupTestCase == nullptr) |
|
{ |
|
ARMNN_LOG(error) << "Failed to load test case"; |
|
return false; |
|
} |
|
|
|
try |
|
{ |
|
warmupTestCase->Run(); |
|
} |
|
catch (const TestFrameworkException& testError) |
|
{ |
|
ARMNN_LOG(error) << testError.what(); |
|
return false; |
|
} |
|
|
|
const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount |
|
: static_cast<unsigned int>(defaultTestCaseIds.size()); |
|
|
|
for (; nbProcessed < nbTotalToProcess; nbProcessed++) |
|
{ |
|
const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed]; |
|
std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId); |
|
|
|
if (testCase == nullptr) |
|
{ |
|
ARMNN_LOG(error) << "Failed to load test case"; |
|
return false; |
|
} |
|
|
|
time_point<high_resolution_clock> predictStart; |
|
time_point<high_resolution_clock> predictEnd; |
|
|
|
TestCaseResult result = TestCaseResult::Ok; |
|
|
|
try |
|
{ |
|
predictStart = high_resolution_clock::now(); |
|
|
|
testCase->Run(); |
|
|
|
predictEnd = high_resolution_clock::now(); |
|
|
|
// duration<double> will convert the time difference into seconds as a double by default. |
|
double timeTakenS = duration<double>(predictEnd - predictStart).count(); |
|
totalTime += timeTakenS; |
|
|
|
// Outputss inference times, if needed. |
|
if (recordInferenceTimes) |
|
{ |
|
inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl; |
|
} |
|
|
|
result = testCase->ProcessResult(params); |
|
|
|
} |
|
catch (const TestFrameworkException& testError) |
|
{ |
|
ARMNN_LOG(error) << testError.what(); |
|
result = TestCaseResult::Abort; |
|
} |
|
|
|
switch (result) |
|
{ |
|
case TestCaseResult::Ok: |
|
break; |
|
case TestCaseResult::Abort: |
|
return false; |
|
case TestCaseResult::Failed: |
|
// This test failed so we will fail the entire program eventually, but keep going for now. |
|
success = false; |
|
break; |
|
default: |
|
ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult"); |
|
return false; |
|
} |
|
} |
|
|
|
const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f; |
|
|
|
ARMNN_LOG(info) << std::fixed << std::setprecision(3) << |
|
"Total time for " << nbProcessed << " test cases: " << totalTime << " seconds"; |
|
ARMNN_LOG(info) << std::fixed << std::setprecision(3) << |
|
"Average time per test case: " << averageTimePerTestCaseMs << " ms"; |
|
|
|
// if profiling is enabled print out the results |
|
if (profiler && profiler->IsProfilingEnabled()) |
|
{ |
|
profiler->Print(std::cout); |
|
} |
|
|
|
if (!success) |
|
{ |
|
ARMNN_LOG(error) << "One or more test cases failed"; |
|
return false; |
|
} |
|
|
|
return testCaseProvider.OnInferenceTestFinished(); |
|
} |
|
|
|
} // namespace test |
|
|
|
} // namespace armnn
|
|
|