mirror of https://github.com/Qortal/Brooklyn
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
53 lines
2.1 KiB
53 lines
2.1 KiB
# Copyright © 2020 Arm Ltd and Contributors. All rights reserved. |
|
# SPDX-License-Identifier: MIT |
|
|
|
import tflite_runtime.interpreter as tflite |
|
import numpy as np |
|
import os |
|
|
|
|
|
def run_mock_model(delegate, test_data_folder):
    """Run the mock TFLite model once through the given delegate.

    Loads ``mock_model.tflite`` from *test_data_folder*, attaches *delegate*
    as the interpreter's only experimental delegate, fills the first input
    tensor with random ``uint8`` data and invokes the interpreter. Nothing is
    returned; the call succeeds if no exception is raised.

    Args:
        delegate: Delegate object passed to ``tflite.Interpreter`` through
            ``experimental_delegates``.
        test_data_folder: Directory containing ``mock_model.tflite``.
    """
    model_path = os.path.join(test_data_folder, 'mock_model.tflite')
    interpreter = tflite.Interpreter(model_path=model_path,
                                     experimental_delegates=[delegate])
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()

    # Test model on random input data. Draw integers across the whole uint8
    # range: casting np.random.random_sample() output (floats in [0, 1)) to
    # uint8 would truncate every element to 0.
    # NOTE(review): assumes the mock model's first input tensor is uint8, as
    # the hard-coded dtype implies — confirm against the model.
    input_shape = tuple(input_details[0]['shape'])
    input_data = np.random.randint(0, 256, size=input_shape, dtype=np.uint8)
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()
|
|
|
def run_inference(test_data_folder, model_filename, inputs, delegates=None):
    """Run a TFLite model on the given inputs and return every output tensor.

    Args:
        test_data_folder: Directory that holds the model file.
        model_filename: Name of the ``.tflite`` file inside *test_data_folder*.
        inputs: Sequence of arrays; element ``k`` is written to the model's
            ``k``-th input tensor (in ``get_input_details()`` order).
        delegates: Optional value forwarded unchanged to
            ``tflite.Interpreter(experimental_delegates=...)``.

    Returns:
        A list holding one array per model output, in
        ``get_output_details()`` order.
    """
    interpreter = tflite.Interpreter(
        model_path=os.path.join(test_data_folder, model_filename),
        experimental_delegates=delegates,
    )
    interpreter.allocate_tensors()

    # Feed each provided array to the input tensor at the same position.
    in_specs = interpreter.get_input_details()
    for position, tensor in enumerate(inputs):
        interpreter.set_tensor(in_specs[position]['index'], tensor)

    interpreter.invoke()

    # Collect the result of every output tensor, preserving output order.
    return [interpreter.get_tensor(spec['index'])
            for spec in interpreter.get_output_details()]
|
|
|
def compare_outputs(outputs, expected_outputs, *, rtol=0.0, atol=0.0):
    """Assert that *outputs* match *expected_outputs* position by position.

    For every position the shapes and dtypes must be identical and the element
    values must agree: exactly by default, or within
    ``|actual - expected| <= atol + rtol * |expected|`` when a tolerance is
    given (useful for floating-point outputs).

    Args:
        outputs: Sequence of numpy arrays produced by the model under test.
        expected_outputs: Sequence of reference numpy arrays.
        rtol: Relative tolerance for the value check (keyword-only).
        atol: Absolute tolerance for the value check (keyword-only).

    Raises:
        AssertionError: If the number of outputs, or any output's shape,
            dtype or values, differ.
    """
    assert len(outputs) == len(expected_outputs), 'Incorrect number of outputs'
    for i, (actual, expected) in enumerate(zip(outputs, expected_outputs)):
        assert actual.shape == expected.shape, 'Incorrect output shape on output#{}'.format(i)
        assert actual.dtype == expected.dtype, 'Incorrect output data type on output#{}'.format(i)
        # Compare the element values themselves. Comparing `actual.all()` with
        # `expected.all()` would only check whether each array is entirely
        # non-zero, so arrays holding different values would pass.
        if rtol == 0 and atol == 0:
            values_match = np.array_equal(actual, expected)
        else:
            values_match = np.allclose(actual, expected, rtol=rtol, atol=atol)
        assert values_match, 'Incorrect output value on output#{}'.format(i)