1#!/usr/bin/env python3
2#
3# SPDX-FileCopyrightText: Copyright 2010-2020, 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the License); you may
8# not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an AS IS BASIS, WITHOUT
15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18#
19import os
20import re
21import sys
22import json
23import copy
24import glob
25import time
26import queue
27import shutil
28import serial
29import argparse
30import threading
31import subprocess
32
33from os import path
34from termcolor import colored
35
# Paths used by the test runner. OUTPUT and UNITY_PATH are relative to the
# directory the script is run from. BASE_PATH (and therefore UNITY_BASE and
# UNITY_SRC) is used after test_target() has changed into
# OUTPUT/<run directory>/, from where "../../" leads back to that directory.
OUTPUT = "Output/"
BASE_PATH = "../../"
UNITY_PATH = "Unity/"
UNITY_BASE = "".join((BASE_PATH, UNITY_PATH))
UNITY_SRC = "{}src/".format(UNITY_BASE)
41
42
def parse_args(argv=None):
    """Parse the command line of the CMSIS-NN unit-test runner.

    Args:
        argv: Arguments to parse, without the program name. When None (the
            default), argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``testdir``, ``specific_test``, ``compiler``,
        ``download_and_generate`` and ``path_to_cmsis``.
    """
    parser = argparse.ArgumentParser(description="Run CMSIS-NN unit tests.",
                                     epilog="Runs on all connected HW supported by Mbed.")
    parser.add_argument('--testdir', type=str, default='TESTRUN', help="prefix of output dir name")
    parser.add_argument('-s', '--specific-test', type=str, default=None, help="Run a specific test, e.g."
                        " -s TestCases/test_arm_avgpool_s8 (also this form will work: -s test_arm_avgpool_s8)."
                        " So basically the different options can be listed with:"
                        " ls -d TestCases/test_* -1")
    parser.add_argument('-c', '--compiler', type=str, default='GCC_ARM', choices=['GCC_ARM', 'ARMC6'])
    parser.add_argument('--download-and-generate-test-runners', dest='download_and_generate', action='store_true',
                        help="Just download Unity and generate test runners if needed")

    required_arguments = parser.add_argument_group('required named arguments')
    required_arguments.add_argument('-p', '--path-to-cmsis', help='Path to CMSIS project for CMSIS/Core.',
                                    required=True)

    return parser.parse_args(argv)
61
62
def error_handler(code, text=None):
    """Print an error message and terminate the script.

    Args:
        code: Exit status passed to sys.exit(). A missing (None) or zero status
            is replaced by 1 so that an error can never look like success.
        text: Optional description. When omitted, the exit status is reported
            instead of the literal "None".
    """
    if code is None or code == 0:
        code = 1
    if text is None:
        text = "exit code {}".format(code)
    print("Error: {}".format(text))
    sys.exit(code)
66
67
def detect_targets(targets):
    """Append every board listed by ``mbedls`` to *targets*.

    The mbedls output is echoed to stdout. The first line (presumably the table
    header) is echoed but not parsed. Every following line starting with "| "
    becomes a dict with keys:
      - ``model``: column 1;
      - ``name``: column 2 with its last character dropped and "[" replaced by "_";
      - ``port``: column 4;
      - ``tid``: column 5.

    Exits via error_handler() if mbedls returns a non-zero status.
    """
    with subprocess.Popen(['mbedls'],
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE,
                          universal_newlines=True) as process:
        print(process.stdout.readline().strip())
        for line in process.stdout:
            print(line.strip())
            if re.search(r"^\| ", line):
                words = line.split('| ')
                targets.append({"model": words[1].strip(),
                                "name": words[2].strip()[:-1].replace('[', '_'),
                                "port": words[4].strip(),
                                "tid": words[5].strip()})  # Target id can be used to filter out targets
        # Drain stderr, then wait() for the real exit status: right after stdout
        # reaches EOF the process may not have exited yet, and poll() would
        # return None instead of a status.
        stderr_output = process.stderr.read()
        return_code = process.wait()
    if return_code != 0:
        error_handler(return_code, 'RETURN CODE {}: {}'.format(return_code, stderr_output.strip()))
89
90
def run_command(command, error_msg=None, die=True):
    """Run an external command and return its exit status.

    Args:
        command: Either a command string, split on runs of whitespace (so no
            empty arguments are produced), or a sequence of arguments passed
            through unchanged. Use the sequence form for arguments that contain
            spaces.
        error_msg: Message passed to error_handler() when the command fails.
        die: When True (default), a non-zero exit status terminates the script
            through error_handler(); when False the status is only returned.

    Returns:
        The command's exit status (0 on success).
    """
    # TODO handle error:
    # cp: error writing '/media/mannil01/NODE_F411RE/TESTRUN_NUCLEO_F411RE_GCC_ARM.bin': No space left on device
    # https://os.mbed.com/questions/59636/STM-Nucleo-No-space-left-on-device-when-/
    if isinstance(command, str):
        command_list = command.split()
    else:
        command_list = list(command)
    process = subprocess.run(command_list)
    if die and process.returncode != 0:
        error_handler(process.returncode, error_msg)
    return process.returncode
102
103
def detect_architecture(target_name, target_json):
    """Return the Cortex-M core family of an Mbed target, e.g. ``"Cortex-M4"``.

    The ``core`` entry is looked up in the Mbed OS targets JSON file. It comes
    from the target itself if present; otherwise its ``inherits`` parents are
    searched depth-first, in declaration order.

    Only the ``Cortex-M<digits>`` prefix is returned. Variant suffixes such as
    ``F``, ``FD`` or ``FE-NS`` are dropped, while multi-digit cores such as
    ``Cortex-M33`` and ``Cortex-M55`` are kept intact.

    Args:
        target_name: Key of the target in the JSON file (the mbedls model name).
        target_json: Path to ``targets.json``.

    Returns:
        The architecture string. Exits via error_handler() with 167 if the file
        cannot be read or no core is found, and with 168 if the core is not a
        Cortex-M.
    """
    core = None
    try:
        with open(target_json, "r") as read_file:
            data = json.load(read_file)
        core = _find_target_core(data, target_name, set())
    except Exception as e:
        error_handler(167, e)

    match = re.match(r'Cortex-M\d+', core) if core else None
    if match:
        return match.group(0)
    error_handler(168, 'Unsupported target: {} with architecture: {}'.format(
        target_name, core))
    return None


def _find_target_core(data, target_name, visited):
    """Return the ``core`` of *target_name*, following ``inherits`` depth-first.

    Raises KeyError if *target_name* is not in *data*. Raises LookupError
    ("Cannot detect architecture") if neither the target nor any ancestor
    defines ``core``, or if the inheritance graph loops.
    """
    if target_name in visited:
        raise LookupError("Cannot detect architecture")
    visited.add(target_name)
    entry = data[target_name]
    if 'core' in entry:
        return entry['core']
    for parent in entry.get('inherits', []):
        try:
            return _find_target_core(data, parent, visited)
        except LookupError:
            # This parent (or a missing parent entry) gave no core; try the next one.
            continue
    raise LookupError("Cannot detect architecture")
129
130
def test_target(target, args, main_test):
    """Build one Unity test runner for *target* with Mbed CLI and flash it.

    The per-target output directory
    ``OUTPUT/<testdir>_UNITY_UNITTEST_<target name>_<compiler>`` is created if
    needed. Mbed OS is initialised there on first use (when ``mbed-os.lib`` is
    absent), and then ``mbed compile ... -f`` is run. The working directory is
    always restored before returning.

    Args:
        target: Target dict from detect_targets() (uses 'name' and 'model').
        args: Parsed command line (uses compiler, path_to_cmsis, testdir).
        main_test: Test directory relative to the script directory, e.g.
            ``TestCases/test_arm_avgpool_s8/Unity/``.

    Returns:
        The exit status of ``mbed compile``. Unexpected exceptions terminate
        the script via error_handler(166).
    """
    result = 3
    compiler = args.compiler
    cmsis_nn_path = "../../../../"
    target_name = target['name']
    target_model = target['model']
    unittestframework = 'UNITY_UNITTEST'

    dir_name = OUTPUT + args.testdir + '_' + unittestframework + '_' + target_name + '_' + compiler
    os.makedirs(dir_name, exist_ok=True)

    # Make the CMSIS Core include path relative to the output directory *before*
    # changing into it, so a relative --path-to-cmsis is interpreted against the
    # directory the script was started from. Absolute paths give the same result
    # either way.
    cmsis_core_include_path = args.path_to_cmsis + '/CMSIS/Core/Include/'
    relative_cmsis_core_include_path = os.path.relpath(cmsis_core_include_path, dir_name)

    start_dir = os.getcwd()
    os.chdir(dir_name)
    try:
        target_json = 'mbed-os/targets/targets.json'
        mbed_path = BASE_PATH + 'Mbed/'

        if not path.exists("mbed-os.lib"):
            print("Initializing mbed in {}".format(os.getcwd()))
            shutil.copyfile(mbed_path + 'mbed-os.lib', 'mbed-os.lib')
            shutil.copyfile(mbed_path + 'mbed_app.json', 'mbed_app.json')
            run_command('mbed config root .')
            run_command('mbed deploy')
        arch = detect_architecture(target_model, target_json)

        print("----------------------------------------------------------------")
        print("Running {} on {} target: {} with compiler: {} in directory: {} test: {}\n".format(
            unittestframework, arch, target_name, compiler, os.getcwd(), main_test))

        nn_function_dirs = ('ConvolutionFunctions', 'PoolingFunctions', 'NNSupportFunctions',
                            'FullyConnectedFunctions', 'SoftmaxFunctions', 'SVDFunctions',
                            'BasicMathFunctions')
        source_dirs = ['.',
                       BASE_PATH + 'TestCases/Utils/',
                       cmsis_nn_path + 'Include/',
                       relative_cmsis_core_include_path]
        source_dirs += [cmsis_nn_path + 'Source/' + sub_dir + '/' for sub_dir in nn_function_dirs]
        source_dirs += [BASE_PATH + main_test, UNITY_SRC]

        command = ['mbed', 'compile', '-v', '-m', target_model, '-t', compiler]
        for source_dir in source_dirs:
            command += ['--source', source_dir]
        command += ['--profile', mbed_path + 'release.json', '-f']

        # die=False: a failed build/flash does not abort the script; the failure
        # is only visible in the returned status.
        result = run_command(' '.join(command), 'failed to flash', die=False)

    except Exception as e:
        error_handler(166, e)
    finally:
        os.chdir(start_dir)

    return result
195
196
def read_serial_port(ser, inputQueue, stop):
    """Forward lines read from *ser* into *inputQueue* until ``stop()`` is true.

    Args:
        ser: Open serial port whose ``readline()`` returns bytes.
        inputQueue: Queue that receives each line decoded as Latin-1 with
            surrounding whitespace stripped. An empty read is forwarded as "".
        stop: Zero-argument callable checked before every read.
    """
    while not stop():
        raw_line = ser.readline()
        inputQueue.put(raw_line.decode('latin-1').strip())
203
204
205def test_target_with_unity(target, args, main_test):
206    port = target['port']
207    stop_thread = False
208    baudrate = 9600
209    timeout = 30
210    inputQueue = queue.Queue()
211    tests = copy.deepcopy(target["tests"])
212
213    try:
214        ser = serial.Serial(port, baudrate, timeout=timeout)
215    except Exception as e:
216        error_handler(169, "serial exception: {}".format(e))
217
218    # Clear read buffer
219    time.sleep(0.1)  # Workaround in response to: open() returns before port is ready
220    ser.reset_input_buffer()
221
222    serial_thread = threading.Thread(target=read_serial_port, args=(ser, inputQueue, lambda: stop_thread), daemon=True)
223    serial_thread.start()
224
225    test_target(target, args, main_test)
226
227    start_time = time.time()
228    while time.time() < start_time + timeout:
229        if inputQueue.qsize() > 0:
230            str_line = inputQueue.get()
231            print(str_line)
232            test = None
233            try:
234                test = str_line.split(':')[2]
235                test_result = ':'.join(str_line.split(':')[2:4])
236            except IndexError:
237                pass
238            if test in tests:
239                tests.remove(test)
240                target[test]["tested"] = True
241                if test_result == test + ':PASS':
242                    target[test]["pass"] = True
243            if len(tests) == 0:
244                break
245
246    stop_thread = True
247    serial_thread.join()
248    ser.close()
249
250
251def print_summary(targets):
252    """
253    Return 0 if all test passed
254    Return 1 if all test completed but one or more failed
255    Return 2 if one or more tests did not complete or was not detected
256    """
257    passed = 0
258    failed = 0
259    tested = 0
260    expected = 0
261    return_code = 3
262    verdict_pass = colored('[ PASSED ]', 'green')
263    verdict_fail = colored('[ FAILED ]', 'red')
264    verdict_error = colored('[ ERROR ]', 'red')
265
266    print("-----------------------------------------------------------------------------------------------------------")
267
268    # Find all passed and failed
269    for target in targets:
270        for test in target["tests"]:
271            expected += 1
272            if target[test]["tested"]:
273                tested += 1
274            if target[test]["pass"]:
275                passed += 1
276            else:
277                failed += 1
278
279    if tested != expected:
280        print("{} Not all tests found".format(verdict_error))
281        print("{} Expected: {} Actual: {}".format(verdict_error, expected, tested))
282        return_code = 2
283    elif tested == passed:
284        return_code = 0
285    else:
286        return_code = 1
287
288    # print all test cases
289    sorted_tc = []
290    for target in targets:
291        for test in target["tests"]:
292            if not target[test]["tested"]:
293                tc_verdict = verdict_error
294            elif target[test]["pass"]:
295                tc_verdict = verdict_pass
296            else:
297                tc_verdict = verdict_fail
298            sorted_tc.append("{} {}: {}".format(tc_verdict, target["name"], test))
299    sorted_tc.sort()
300    for tc in sorted_tc:
301        print(tc)
302
303    total = 0
304    if (passed > 0):
305        total = passed / expected
306    if (total == 1.0):
307        verdict = verdict_pass
308    else:
309        verdict = verdict_fail
310    print("{} Summary: {} tests in total passed on {} target(s) ({})".
311          format(verdict, passed, len(targets), ', '.join([t['name'] for t in targets])))
312    print("{} {:.0f}% tests passed, {} tests failed out of {}".format(verdict, total*100, failed, expected))
313
314    return return_code
315
316
def test_targets(args):
    """Run the unit-test session described by the parsed command line *args*.

    Returns:
        3 if no targets are detected (not checked in download-and-generate mode).
        4 if no tests are found.
        0 in download-and-generate mode, after Unity is fetched and the runners
          are generated.
        Otherwise the print_summary() code: 0 all passed, 1 some failed,
          2 some never reported.
    """
    targets = []
    main_tests = []
    generate_only = args.download_and_generate

    if not generate_only:
        detect_targets(targets)
        if not targets:
            print("No targets detected!")
            return 3

    download_unity()

    if not parse_tests(targets, main_tests, args.specific_test):
        print("No tests found?!")
        return 4

    if generate_only:
        return 0

    for target in targets:
        for main_test in main_tests:
            test_target_with_unity(target, args, main_test)

    return print_summary(targets)
349
350
def download_unity(force=False):
    """Download Unity v2.5.0 into UNITY_PATH unless it is already present.

    Args:
        force: Re-download even when ``UNITY_PATH/src/unity.c`` and
            ``unity.h`` already exist.

    The tarball is fetched with curl into a private temporary directory and
    unpacked into UNITY_PATH with tar, dropping the archive's top-level
    directory. An existing UNITY_PATH is only replaced after the download
    succeeds. Exits via error_handler() if curl or tar fails.
    """
    import tempfile  # local import keeps this block self-contained

    unity_dir = UNITY_PATH
    unity_src = unity_dir + "src/"

    # Check if already downloaded (before creating any temporary files)
    if not force and path.isdir(unity_dir) and path.isfile(unity_src + "unity.c") and path.isfile(unity_src + "unity.h"):
        return

    url = 'https://api.github.com/repos/ThrowTheSwitch/Unity/tarball/v2.5.0'
    # TemporaryDirectory creates a uniquely named directory atomically and
    # removes it on exit, including when error_handler() exits the script.
    with tempfile.TemporaryDirectory() as download_dir:
        downloaded_file = os.path.join(download_dir, 'unity_tarball.tar.gz')
        # --fail turns HTTP errors (e.g. API rate limiting) into a non-zero exit
        # status instead of saving the error page as the tarball.
        curl = subprocess.run(['curl', '-L', '--fail', url, '--output', downloaded_file])
        print()
        if curl.returncode != 0 or not path.isfile(downloaded_file):
            error_handler(171, 'failed to download Unity from {}'.format(url))

        if path.isdir(unity_dir):
            shutil.rmtree(unity_dir)
        os.mkdir(unity_dir)
        tar = subprocess.run(['tar', 'xzf', downloaded_file, '-C', unity_dir, '--strip-components=1'])
        if tar.returncode != 0:
            error_handler(tar.returncode, 'failed to extract {}'.format(downloaded_file))
395
396
def parse_generated_test_runner(test_runner):
    """Remove fixture prototypes from a generated Unity test runner, in place.

    Every line that contains ``(void);`` together with one of ``setUp``,
    ``tearDown``, ``resetTest`` or ``verifyTest`` is dropped; both are plain
    substring matches. All other lines are written back unchanged. This avoids
    "redundant redeclaration" compiler warnings.
    """
    fixture_hooks = ('setUp', 'tearDown', 'resetTest', 'verifyTest')

    def is_fixture_prototype(text):
        if re.search(r"\(void\);", text) is None:
            return False
        return any(hook in text for hook in fixture_hooks)

    with open(test_runner, "r") as runner:
        content = runner.readlines()

    kept = [ln for ln in content if not is_fixture_prototype(ln.strip('\n'))]

    with open(test_runner, "w") as runner:
        runner.writelines(kept)
416
417
418def parse_tests(targets, main_tests, specific_test=None):
419    """
420    Generate test runners, extract and return path to unit test(s).
421    Also parse generated test runners to avoid warning: redundant redeclaration.
422    Return True if successful.
423    """
424    test_found = False
425    directory = 'TestCases'
426
427    if specific_test and '/' in specific_test:
428        specific_test = specific_test.strip(directory).replace('/', '')
429
430    for dir in next(os.walk(directory))[1]:
431        if re.search(r'test_arm', dir):
432            if specific_test and dir != specific_test:
433                continue
434            test_found = True
435            testpath = directory + '/' + dir + '/Unity/'
436            ut_test_file = None
437            for content in os.listdir(testpath):
438                if re.search(r'unity_test_arm', content):
439                    ut_test_file = content
440            if ut_test_file is None:
441                print("Warning: invalid path: ", testpath)
442                continue
443            main_tests.append(testpath)
444            ut_test_file_runner = path.splitext(ut_test_file)[0] + '_runner' + path.splitext(ut_test_file)[1]
445            test_code = testpath + ut_test_file
446            test_runner_path = testpath + 'TestRunner/'
447            if not os.path.exists(test_runner_path):
448                os.mkdir(test_runner_path)
449            test_runner = test_runner_path + ut_test_file_runner
450            for old_files in glob.glob(test_runner_path + '/*'):
451                if not old_files.endswith('readme.txt'):
452                    os.remove(old_files)
453
454            # Generate test runners
455            run_command('ruby '+UNITY_PATH+'auto/generate_test_runner.rb ' + test_code + ' ' + test_runner)
456            test_found = parse_test(test_runner, targets)
457            if not test_found:
458                return False
459
460            parse_generated_test_runner(test_runner)
461
462    if not test_found:
463        return False
464    return True
465
466
def parse_test(test_runner, targets):
    """Register the tests listed in a generated Unity runner on every target.

    A line is treated as a test invocation when it contains "  run_test(" and
    splits into exactly three comma-separated parts. The text between "(" and
    the first comma is the test name. For each target, that name is appended
    to ``target["tests"]`` (created if missing), and ``target[<name>]`` is
    reset to ``{"pass": False, "tested": False}``.

    Returns:
        True if at least one test invocation was found, otherwise False.
        Exits via error_handler(170) if the runner cannot be opened.
    """
    found_any = False
    try:
        runner = open(test_runner, "r")
    except IOError as err:
        error_handler(170, err)
        return found_any

    with runner:
        for raw_line in runner:
            pieces = raw_line.strip().split(',')
            if not re.search(r"  run_test\(", raw_line) or len(pieces) != 3:
                continue
            test_name = pieces[0].split('(')[1]
            found_any = True
            for tgt in targets:
                tgt.setdefault('tests', []).append(test_name)
                tgt[test_name] = {"pass": False, "tested": False}
    return found_any
491
492
if __name__ == '__main__':
    # Script entry point: the process exit status is whatever test_targets() returns.
    args = parse_args()
    exit_status = test_targets(args)
    sys.exit(exit_status)
496