mirror of https://github.com/Qortal/Brooklyn
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
198 lines
7.1 KiB
198 lines
7.1 KiB
// |
|
// Copyright © 2017 Arm Ltd. All rights reserved. |
|
// SPDX-License-Identifier: MIT |
|
// |
|
#pragma once |
|
|
|
#include "InferenceTest.hpp" |
|
#include "DeepSpeechV1Database.hpp" |
|
|
|
#include <armnn/utility/Assert.hpp> |
|
#include <armnn/utility/IgnoreUnused.hpp> |
|
#include <armnnUtils/FloatingPointComparison.hpp> |
|
|
|
#include <vector> |
|
|
|
namespace |
|
{ |
|
|
|
template<typename Model> |
|
class DeepSpeechV1TestCase : public InferenceModelTestCase<Model> |
|
{ |
|
public: |
|
DeepSpeechV1TestCase(Model& model, |
|
unsigned int testCaseId, |
|
const DeepSpeechV1TestCaseData& testCaseData) |
|
: InferenceModelTestCase<Model>(model, |
|
testCaseId, |
|
{ testCaseData.m_InputData.m_InputSeq, |
|
testCaseData.m_InputData.m_StateH, |
|
testCaseData.m_InputData.m_StateC}, |
|
{ k_OutputSize1, k_OutputSize2, k_OutputSize3 }) |
|
, m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH, |
|
testCaseData.m_ExpectedOutputData.m_StateC}) |
|
{} |
|
|
|
TestCaseResult ProcessResult(const InferenceTestOptions& options) override |
|
{ |
|
armnn::IgnoreUnused(options); |
|
const std::vector<float>& output1 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]); // logits |
|
ARMNN_ASSERT(output1.size() == k_OutputSize1); |
|
|
|
const std::vector<float>& output2 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c |
|
ARMNN_ASSERT(output2.size() == k_OutputSize2); |
|
|
|
const std::vector<float>& output3 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h |
|
ARMNN_ASSERT(output3.size() == k_OutputSize3); |
|
|
|
// Check each output to see whether it is the expected value |
|
for (unsigned int j = 0u; j < output1.size(); j++) |
|
{ |
|
if(!armnnUtils::within_percentage_tolerance(output1[j], m_ExpectedOutputs.m_InputSeq[j])) |
|
{ |
|
ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() << |
|
" is incorrect at" << j; |
|
return TestCaseResult::Failed; |
|
} |
|
} |
|
|
|
for (unsigned int j = 0u; j < output2.size(); j++) |
|
{ |
|
if(!armnnUtils::within_percentage_tolerance(output2[j], m_ExpectedOutputs.m_StateH[j])) |
|
{ |
|
ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() << |
|
" is incorrect"; |
|
return TestCaseResult::Failed; |
|
} |
|
} |
|
|
|
for (unsigned int j = 0u; j < output3.size(); j++) |
|
{ |
|
if(!armnnUtils::within_percentage_tolerance(output3[j], m_ExpectedOutputs.m_StateC[j])) |
|
{ |
|
ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() << |
|
" is incorrect"; |
|
return TestCaseResult::Failed; |
|
} |
|
} |
|
return TestCaseResult::Ok; |
|
} |
|
|
|
private: |
|
|
|
static constexpr unsigned int k_OutputSize1 = 464u; |
|
static constexpr unsigned int k_OutputSize2 = 2048u; |
|
static constexpr unsigned int k_OutputSize3 = 2048u; |
|
|
|
LstmInput m_ExpectedOutputs; |
|
}; |
|
|
|
template <typename Model> |
|
class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider |
|
{ |
|
public: |
|
template <typename TConstructModelCallable> |
|
explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel) |
|
: m_ConstructModel(constructModel) |
|
{} |
|
|
|
virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override |
|
{ |
|
options |
|
.allow_unrecognised_options() |
|
.add_options() |
|
("s,input-seq-dir", "Path to directory containing test data for m_InputSeq", |
|
cxxopts::value<std::string>(m_InputSeqDir)) |
|
("h,prev-state-h-dir", "Path to directory containing test data for m_PrevStateH", |
|
cxxopts::value<std::string>(m_PrevStateHDir)) |
|
("c,prev-state-c-dir", "Path to directory containing test data for m_PrevStateC", |
|
cxxopts::value<std::string>(m_PrevStateCDir)) |
|
("l,logits-dir", "Path to directory containing test data for m_Logits", |
|
cxxopts::value<std::string>(m_LogitsDir)) |
|
("H,new-state-h-dir", "Path to directory containing test data for m_NewStateH", |
|
cxxopts::value<std::string>(m_NewStateHDir)) |
|
("C,new-state-c-dir", "Path to directory containing test data for m_NewStateC", |
|
cxxopts::value<std::string>(m_NewStateCDir)); |
|
|
|
required.insert(required.end(), {"input-seq-dir", "prev-state-h-dir", "prev-state-c-dir", "logits-dir", |
|
"new-state-h-dir", "new-state-c-dir"}); |
|
|
|
Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required); |
|
} |
|
|
|
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override |
|
{ |
|
if (!ValidateDirectory(m_InputSeqDir)) |
|
{ |
|
return false; |
|
} |
|
|
|
if (!ValidateDirectory(m_PrevStateCDir)) |
|
{ |
|
return false; |
|
} |
|
|
|
if (!ValidateDirectory(m_PrevStateHDir)) |
|
{ |
|
return false; |
|
} |
|
|
|
if (!ValidateDirectory(m_LogitsDir)) |
|
{ |
|
return false; |
|
} |
|
|
|
if (!ValidateDirectory(m_NewStateCDir)) |
|
{ |
|
return false; |
|
} |
|
|
|
if (!ValidateDirectory(m_NewStateHDir)) |
|
{ |
|
return false; |
|
} |
|
|
|
m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions); |
|
if (!m_Model) |
|
{ |
|
return false; |
|
} |
|
m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(), |
|
m_PrevStateCDir.c_str(), m_LogitsDir.c_str(), |
|
m_NewStateHDir.c_str(), m_NewStateCDir.c_str()); |
|
if (!m_Database) |
|
{ |
|
return false; |
|
} |
|
|
|
return true; |
|
} |
|
|
|
std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override |
|
{ |
|
std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId); |
|
if (!testCaseData) |
|
{ |
|
return nullptr; |
|
} |
|
|
|
return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData); |
|
} |
|
|
|
private: |
|
typename Model::CommandLineOptions m_ModelCommandLineOptions; |
|
std::function<std::unique_ptr<Model>(const InferenceTestOptions&, |
|
typename Model::CommandLineOptions)> m_ConstructModel; |
|
std::unique_ptr<Model> m_Model; |
|
|
|
std::string m_InputSeqDir; |
|
std::string m_PrevStateCDir; |
|
std::string m_PrevStateHDir; |
|
std::string m_LogitsDir; |
|
std::string m_NewStateCDir; |
|
std::string m_NewStateHDir; |
|
|
|
std::unique_ptr<DeepSpeechV1Database> m_Database; |
|
}; |
|
|
|
} // anonymous namespace
|
|
|