mirror of https://github.com/Qortal/Brooklyn
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
346 lines
13 KiB
346 lines
13 KiB
#!/usr/bin/env python3 |
|
# Copyright © 2020 Arm Ltd. All rights reserved. |
|
# Copyright © 2020 NXP and Contributors. All rights reserved. |
|
# SPDX-License-Identifier: MIT |
|
"""Python bindings for Arm NN |
|
|
|
PyArmNN is a python extension for Arm NN SDK providing an interface similar to Arm NN C++ API. |
|
""" |
|
# Placeholders for the package metadata. setup() reads __version__ and the
# extension builder reads __arm_ml_version__; presumably the real values are
# assigned when src/pyarmnn/_version.py is exec'd further down this file.
__version__ = __arm_ml_version__ = None
|
|
|
import logging |
|
import os |
|
import sys |
|
import subprocess |
|
from functools import lru_cache |
|
from pathlib import Path |
|
from itertools import chain |
|
|
|
from setuptools import setup |
|
from distutils.core import Extension |
|
from setuptools.command.build_py import build_py |
|
from setuptools.command.build_ext import build_ext |
|
|
|
# Obtain the logger through the logging registry (Loggers must not be
# instantiated directly) so records propagate to any handlers configured on
# the root logger.
logger = logging.getLogger(__name__)

# Description lines for setup(): index 0 is the short summary and indices 2+
# form the long description. `__doc__` is None when docstrings are stripped
# (python -OO), so fall back to an empty string instead of crashing.
DOCLINES = (__doc__ or "").split("\n")

# Environment variables that point the search at custom ArmNN library /
# header locations.
LIB_ENV_NAME = "ARMNN_LIB"

INCLUDE_ENV_NAME = "ARMNN_INCLUDE"
|
|
|
|
|
def check_armnn_version(*args):
    """Check installed vs. expected Arm NN version -- currently a no-op.

    The extension builder calls this with the version reported by the freshly
    built `_pyarmnn_version` module and with `__arm_ml_version__`. Every
    argument is ignored and nothing is validated.

    NOTE(review): with this stub a mismatching Arm NN installation is not
    detected at build time.
    """
    del args
    return None
|
|
|
__current_dir = os.path.dirname(os.path.realpath(__file__))

# Execute the package's version module in this module's namespace; it is
# expected (not verified here) to assign __version__ and __arm_ml_version__.
# The `with` statement closes the file handle deterministically.
with open(os.path.join(__current_dir, 'src', 'pyarmnn', '_version.py'), encoding="utf-8") as __version_file:
    exec(__version_file.read())
del __version_file
|
|
|
|
|
class ExtensionPriorityBuilder(build_py):
    """`build_py` command that triggers `build_ext` before doing its own work.

    The extension build produces generated files; running it first ensures
    those files end up in the distribution.
    """

    def run(self):
        # Extensions first, then the regular pure-Python build step.
        self.run_command('build_ext')
        outcome = super().run()
        return outcome
|
|
|
|
|
class ArmnnVersionCheckerExtBuilder(build_ext):
    """`build_ext` command that tolerates failing optional extensions and,
    once the version wrapper is built, calls `check_armnn_version`.
    """

    def __init__(self, dist):
        super().__init__(dist)
        # Optional extensions whose build raised; pruned before in-place copying.
        self.failed_ext = []

    def build_extension(self, ext):
        if not ext.optional:
            super().build_extension(ext)
        else:
            try:
                super().build_extension(ext)
            except Exception as err:
                # Optional wrappers are best effort: record the failure and continue.
                self.failed_ext.append(ext)
                logger.warning('Failed to build extension %s. \n %s', ext.name, str(err))

        if ext.name != 'pyarmnn._generated._pyarmnn_version':
            return

        # Make the just-built version module importable, then compare the version
        # it reports with __arm_ml_version__.
        module_dir = os.path.join(self.build_lib, str(Path(ext._file_name).parent))
        sys.path.append(os.path.abspath(module_dir))
        from _pyarmnn_version import GetVersion
        check_armnn_version(GetVersion(), __arm_ml_version__)

    def copy_extensions_to_source(self):
        # Failed optional extensions presumably have no built output to copy,
        # so drop them before delegating to the base implementation.
        for failed in self.failed_ext:
            self.extensions.remove(failed)
        super().copy_extensions_to_source()
|
|
|
|
|
def _compiler_runs(compiler):
    """Return True if `<compiler> --version` can be executed and exits with status 0."""
    try:
        subprocess.run([compiler, "--version"], stdout=subprocess.DEVNULL,
                       stderr=subprocess.DEVNULL, check=True)
    except (OSError, subprocess.CalledProcessError):
        # OSError covers a missing/non-executable program (FileNotFoundError,
        # PermissionError); CalledProcessError covers a non-zero exit status.
        return False
    return True


def linux_gcc_name():
    """Returns the name of the `gcc` compiler. Might happen that we are cross-compiling and the
    compiler has a longer name, so the `CC` environment variable is preferred when it works.

    Args:
        None

    Returns:
        str: `CC` if it runs successfully, else "gcc" if that runs, else None.
    """
    cc_env = os.getenv('CC')
    if cc_env is not None and _compiler_runs(cc_env):
        return cc_env
    return "gcc" if _compiler_runs("gcc") else None


def linux_gcc_lib_search(gcc_compiler_name: str = linux_gcc_name()):
    """Calls the `gcc` to get linker default system paths.

    Runs `<gcc> --print-search-dirs` and parses the first line that contains
    "libraries" and an "=" (e.g. `libraries: =/usr/lib/...:/lib/...`).

    Args:
        gcc_compiler_name(str): Name of the GCC compiler. The default is
            computed once, when this function is defined.

    Returns:
        tuple: The ':'-separated search paths (empty entries dropped), or None
        if the compiler output contains no such line.

    Raises:
        RuntimeError: If unable to find GCC (`gcc_compiler_name` is None).
    """
    if gcc_compiler_name is None:
        raise RuntimeError("Unable to find gcc compiler")
    completed = subprocess.run([gcc_compiler_name, "--print-search-dirs"], stdout=subprocess.PIPE)
    output = completed.stdout.decode("utf-8")
    for line in output.splitlines():
        if "libraries" in line and "=" in line:
            # Everything after the first '=' is the path list; strip() removes
            # stray whitespace so the last directory is usable.
            paths = line.split("=", 1)[1].strip()
            return tuple(path for path in paths.split(":") if path)
    return None
|
|
|
|
|
def find_includes(armnn_include_env: str = INCLUDE_ENV_NAME):
    """Searches for ArmNN includes.

    Args:
        armnn_include_env(str): Environment variable holding one or more
            include paths separated by ':'.

    Returns:
        list: The paths from the variable that exist on disk, in their original
        order. If the variable is unset or none of its paths exist, the
        defaults ['/usr/local/include', '/usr/include'] are returned.
    """
    armnn_include_path_raw = os.getenv(armnn_include_env)
    # An unset variable simply contributes no candidate paths.
    candidates = [] if armnn_include_path_raw is None else armnn_include_path_raw.split(":")

    # Keep only the paths that actually exist.
    armnn_include_path_result = [path for path in candidates if os.path.exists(path)]

    # If none exist, revert to the default system include locations.
    if not armnn_include_path_result:
        armnn_include_path_result = ['/usr/local/include', '/usr/include']
    return armnn_include_path_result
|
|
|
|
|
|
|
@lru_cache(maxsize=None)
def find_armnn(lib_name: str,
               optional: bool = False,
               armnn_libs_env: str = LIB_ENV_NAME,
               default_lib_search: tuple = linux_gcc_lib_search()):
    """Searches for ArmNN installation on the local machine.

    Results are cached per argument combination. The key space is the small,
    fixed set of ArmNN library names, so the unbounded cache stays tiny and
    repeated property lookups do not re-glob or re-log.

    Args:
        lib_name(str): Lib name (glob pattern) to find.
        optional(bool): Do not fail if optional. Default is False - fail if library was not found.
        armnn_libs_env(str): Custom environment variable pointing to ArmNN libraries location,
            default is 'ARMNN_LIB'. Several directories may be given, separated by ':'.
        default_lib_search(tuple): Paths to search for ArmNN if the variable named by
            `armnn_libs_env` is not set. Computed once, when this function is defined.

    Returns:
        tuple: (list of library names of the form ':<file name>',
                list of unique directories containing them, in discovery order).

    Raises:
        RuntimeError: If a non-optional library is not found.
    """
    armnn_lib_path = os.getenv(armnn_libs_env)
    if armnn_lib_path is not None:
        # ':'-separated like ARMNN_INCLUDE; a single directory is searched on its own.
        lib_search = armnn_lib_path.split(":")
    else:
        # default_lib_search may be None when gcc printed no search dirs.
        lib_search = default_lib_search or ()

    # Key each match by ':<file name>' (presumably for GNU ld's `-l:<file>` exact-name
    # syntax); a file found in several directories collapses to one entry, last match wins.
    found = chain.from_iterable(Path(lib_path).glob(lib_name) for lib_path in lib_search)
    armnn_libs = {':{}'.format(path.name): path for path in found}

    if not armnn_libs:
        if not optional:
            raise RuntimeError(("ArmNN library {} was not found in {}. Please install ArmNN to one of the standard\n"
                                "locations or set correct ARMNN_INCLUDE and ARMNN_LIB env variables.").format(
                                    lib_name, lib_search))
        logger.warning("""Optional parser library %s was not found in %s and will not be installed.""", lib_name,
                       lib_search)

    # Names of the libs, plus their unique parent directories. dict.fromkeys dedups
    # while keeping a deterministic order, unlike set().
    lib_dirs = list(dict.fromkeys(str(path.absolute().parent) for path in armnn_libs.values()))
    return list(armnn_libs.keys()), lib_dirs
|
|
|
|
|
class LazyArmnnFinderExtension(Extension):
    """`Extension` whose build settings are augmented with ArmNN locations on demand.

    `include_dirs`, `library_dirs`, `runtime_library_dirs` and `libraries` are
    properties. Reading one returns the value supplied at construction extended
    with the result of `find_includes()` (for headers) or `find_armnn()` for
    every library in `armnn_libs`. The search therefore runs when these
    attributes are read, not when the extension object is created.
    """

    def __init__(self, name, sources, armnn_libs, include_dirs=None, define_macros=None, undef_macros=None,
                 library_dirs=None,
                 libraries=None, runtime_library_dirs=None, extra_objects=None, extra_compile_args=None,
                 extra_link_args=None, export_symbols=None, language=None, optional=None, **kw):
        # Backing storage for the properties below.
        self._include_dirs = None
        self._library_dirs = None
        self._runtime_library_dirs = None
        # Library file names/patterns handed to find_armnn().
        self._armnn_libs = armnn_libs
        self._optional = optional if optional is not None else False

        super().__init__(name=name, sources=sources, include_dirs=include_dirs, define_macros=define_macros,
                         undef_macros=undef_macros, library_dirs=library_dirs, libraries=libraries,
                         runtime_library_dirs=runtime_library_dirs, extra_objects=extra_objects,
                         extra_compile_args=extra_compile_args, extra_link_args=extra_link_args,
                         export_symbols=export_symbols, language=language, optional=optional, **kw)

    def _with_armnn(self, base, field):
        """Return `base` extended by item `field` of find_armnn() for each ArmNN library."""
        combined = base
        for armnn_lib in self._armnn_libs:
            combined = combined + find_armnn(armnn_lib, self._optional)[field]
        return combined

    @property
    def include_dirs(self):
        """Configured include dirs followed by the discovered ArmNN include dirs."""
        return self._include_dirs + find_includes()

    @include_dirs.setter
    def include_dirs(self, value):
        self._include_dirs = value

    @property
    def library_dirs(self):
        """Configured library dirs followed by the directories holding the ArmNN libs."""
        return self._with_armnn(self._library_dirs, 1)

    @library_dirs.setter
    def library_dirs(self, value):
        self._library_dirs = value

    @property
    def runtime_library_dirs(self):
        """Configured runtime dirs followed by the directories holding the ArmNN libs."""
        return self._with_armnn(self._runtime_library_dirs, 1)

    @runtime_library_dirs.setter
    def runtime_library_dirs(self, value):
        self._runtime_library_dirs = value

    @property
    def libraries(self):
        """Configured libraries followed by the discovered ArmNN library names."""
        return self._with_armnn(self._libraries, 0)

    @libraries.setter
    def libraries(self, value):
        self._libraries = value

    def __eq__(self, other):
        # Two extensions are equal when they share the exact class and module name.
        return type(other) is type(self) and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Consistent with __eq__: hash by module name only.
        return hash(self.name)
|
|
|
|
|
if __name__ == '__main__':
    # Mandatory wrappers: both link against libarmnn only and are not optional,
    # so a build failure is not swallowed by the extension builder.
    pyarmnn_module = LazyArmnnFinderExtension(
        'pyarmnn._generated._pyarmnn',
        sources=['src/pyarmnn/_generated/armnn_wrap.cpp'],
        extra_compile_args=['-std=c++14'],
        language='c++',
        armnn_libs=['libarmnn.so'],
        optional=False,
    )
    pyarmnn_v_module = LazyArmnnFinderExtension(
        'pyarmnn._generated._pyarmnn_version',
        sources=['src/pyarmnn/_generated/armnn_version_wrap.cpp'],
        extra_compile_args=['-std=c++14'],
        language='c++',
        armnn_libs=['libarmnn.so'],
        optional=False,
    )
    extensions_to_build = [pyarmnn_v_module, pyarmnn_module]

    def add_parsers_ext(name: str, ext_list: list):
        """Append an optional wrapper extension for the ArmNN library `libarmnn<name>.so`."""
        lowered = name.lower()
        ext_list.append(LazyArmnnFinderExtension(
            'pyarmnn._generated._pyarmnn_{}'.format(lowered),
            sources=['src/pyarmnn/_generated/armnn_{}_wrap.cpp'.format(lowered)],
            extra_compile_args=['-std=c++14'],
            language='c++',
            armnn_libs=['libarmnn.so', 'libarmnn{}.so'.format(name)],
            optional=True,
        ))

    # Optional wrappers: a missing library or failed build is logged, not fatal.
    for parser_name in ('OnnxParser', 'TfLiteParser', 'Deserializer'):
        add_parsers_ext(parser_name, extensions_to_build)

    setup(
        name='pyarmnn',
        version=__version__,
        author='Arm Ltd, NXP Semiconductors',
        # NOTE(review): '[email protected]' looks like a web-scrape email-obfuscation
        # artifact rather than a real address; confirm against upstream.
        author_email='[email protected]',
        description=DOCLINES[0],
        long_description="\n".join(DOCLINES[2:]),
        url='https://mlplatform.org/',
        license='MIT',
        keywords='armnn neural network machine learning',
        classifiers=[
            'Development Status :: 3 - Alpha',
            'Intended Audience :: Developers',
            'Intended Audience :: Education',
            'Intended Audience :: Science/Research',
            'License :: OSI Approved :: MIT License',
            'Programming Language :: Python :: 3',
            'Programming Language :: Python :: 3 :: Only',
            'Programming Language :: Python :: 3.6',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
            'Topic :: Scientific/Engineering',
            'Topic :: Scientific/Engineering :: Artificial Intelligence',
            'Topic :: Software Development',
            'Topic :: Software Development :: Libraries',
            'Topic :: Software Development :: Libraries :: Python Modules',
        ],
        package_dir={'': 'src'},
        packages=[
            'pyarmnn',
            'pyarmnn._generated',
            'pyarmnn._quantization',
            'pyarmnn._tensor',
            'pyarmnn._utilities'
        ],
        data_files=[('', ['LICENSE'])],
        python_requires='>=3.5',
        install_requires=['numpy'],
        cmdclass={
            'build_py': ExtensionPriorityBuilder,
            'build_ext': ArmnnVersionCheckerExtBuilder
        },
        ext_modules=extensions_to_build
    )
|
|
|