1#!/usr/bin/env python3
2#
3# SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Licensed under the Apache License, Version 2.0 (the License); you may
8# not use this file except in compliance with the License.
9# You may obtain a copy of the License at
10#
11# www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing, software
14# distributed under the License is distributed on an AS IS BASIS, WITHOUT
15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16# See the License for the specific language governing permissions and
17# limitations under the License.
18#
19import os
20import re
21import sys
22import json
23import copy
24import glob
25import time
26import queue
27import shutil
28import serial
29import argparse
30import threading
31import subprocess
32
33from os import path
34from termcolor import colored
35
# Directory layout used by the runner.
# OUTPUT is relative to the directory the script is started from. Builds run
# inside OUTPUT/<build dir>/, so BASE_PATH ("../../") leads from there back to
# the starting directory, and UNITY_BASE / UNITY_SRC are spelled relative to
# such a build directory. UNITY_PATH itself is relative to the starting directory.
OUTPUT = "Output/"
BASE_PATH = "../../"
UNITY_PATH = "Unity/"
UNITY_BASE = f"{BASE_PATH}{UNITY_PATH}"
UNITY_SRC = f"{UNITY_BASE}src/"
41
42
def parse_args(argv=None):
    """Parse the command line of the CMSIS-NN unit test runner.

    :param argv: list of arguments to parse. None (the default) parses
        sys.argv[1:], which is what the script itself does; passing an explicit
        list makes the parser usable from other code and from tests.
    :return: argparse.Namespace with the attributes testdir, specific_test,
        compiler and download_and_generate.
    """
    parser = argparse.ArgumentParser(description="Run CMSIS-NN unit tests.",
                                     epilog="Runs on all connected HW supported by Mbed.")
    parser.add_argument('--testdir', type=str, default='TESTRUN', help="prefix of output dir name")
    parser.add_argument('-s',
                        '--specific-test',
                        type=str,
                        default=None,
                        help="Run a specific test, e.g."
                        " -s TestCases/test_arm_avgpool_s8 (also this form will work: -s test_arm_avgpool_s8)."
                        " So basically the different options can be listed with:"
                        " ls -d TestCases/test_* -1")
    parser.add_argument('-c', '--compiler', type=str, default='GCC_ARM', choices=['GCC_ARM', 'ARMC6'])
    parser.add_argument('--download-and-generate-test-runners',
                        dest='download_and_generate',
                        action='store_true',
                        help="Just download Unity and generate test runners if needed")

    return parser.parse_args(argv)
65
66
def error_handler(code, text=None):
    """Print an error message and terminate the script.

    :param code: exit status for sys.exit. A falsy code (0 or None) would make
        an error look like success to the caller of the script, so it is
        replaced by 1.
    :param text: description of the error (any object, formatted with str()).
        When omitted, the exit code is reported instead of the word 'None'.
    :raises SystemExit: always.
    """
    if text is None:
        text = "exit code {}".format(code)
    print("Error: {}".format(text))
    sys.exit(code if code else 1)
70
71
def detect_targets(targets):
    """Append the boards reported by ``mbedls`` to *targets*.

    The table printed by mbedls is echoed to stdout. Its first line is echoed
    but never parsed; every following line that starts with '| ' is split on
    '| ' into a dict with the keys 'model', 'name', 'port' and 'tid'.
    Exits via error_handler if mbedls returns a non-zero status.

    :param targets: list that detected targets are appended to (modified in place).
    :return: the same *targets* list.
    """
    # run() uses communicate(): it drains stdout and stderr concurrently and
    # waits for the process to exit. Calling poll() right after stdout hits EOF
    # can still return None (process not reaped yet), and an unread stderr pipe
    # can fill up and block the child.
    process = subprocess.run(['mbedls'], stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                             universal_newlines=True)
    lines = process.stdout.splitlines()
    if lines:
        # Presumably the column header row; it also starts with '| ', so it
        # must not be parsed as a target.
        print(lines[0].strip())
    for line in lines[1:]:
        print(line.strip())
        if re.search(r"^\| ", line):
            words = line.split('| ')
            targets.append({
                "model": words[1].strip(),
                # Drop the last character and turn '[' into '_', so a unique
                # name of the form 'NAME[0]' becomes 'NAME_0'.
                "name": words[2].strip()[:-1].replace('[', '_'),
                "port": words[4].strip(),
                "tid": words[5].strip()  # Target id can be used to filter out targets
            })
    if process.returncode != 0:
        error_handler(process.returncode, 'RETURN CODE {}'.format(process.stderr))
    return targets
92
93
def run_command(command, error_msg=None, die=True):
    """Run a command line without a shell and return its exit status.

    :param command: the command line. It is split on runs of whitespace, so
        individual arguments cannot contain spaces.
    :param error_msg: message handed to error_handler when the command fails.
        If the executable is not found and no message is given, the OS error
        text is used.
    :param die: when True, a non-zero exit status terminates the script via
        error_handler; when False, the status is only returned.
    :return: the command's exit status, or 127 if the executable was not found.
    """
    # TODO handle error:
    # cp: error writing '/media/mannil01/NODE_F411RE/TESTRUN_NUCLEO_F411RE_GCC_ARM.bin': No space left on device
    # https://os.mbed.com/questions/59636/STM-Nucleo-No-space-left-on-device-when-/

    # split() rather than split(' '): repeated or trailing spaces must not turn
    # into empty-string arguments for the child process.
    command_list = command.split()
    try:
        returncode = subprocess.run(command_list).returncode
    except FileNotFoundError as e:
        # Report a missing tool like a shell would ("command not found" = 127)
        # instead of crashing with a traceback.
        returncode = 127
        if error_msg is None:
            error_msg = str(e)
    if die and returncode != 0:
        error_handler(returncode, error_msg)
    return returncode
105
106
def detect_architecture(target_name, target_json):
    """Return the Cortex-M core of an Mbed target as listed in targets.json.

    A target without a 'core' entry is resolved through the first entry of its
    'inherits' list, repeatedly, until a target with a 'core' is found.

    The script exits via error_handler with code 168 if the resolved core is
    empty or does not start with 'Cortex-M'. It exits with code 167 on any
    other failure (missing file, invalid JSON, unknown target, target with
    neither 'core' nor 'inherits', ...).

    :param target_name: Mbed target model name (test_target passes the model
        reported by mbedls).
    :param target_json: path to mbed-os' targets.json.
    :return: the full core name, e.g. 'Cortex-M33'.
    """
    arch = None
    try:
        with open(target_json, "r") as read_file:
            data = json.load(read_file)

        while 'core' not in data[target_name]:
            parent = data[target_name]['inherits'][0]
            print(f"{target_name} inherits from {parent}")
            target_name = parent
        core = data[target_name]['core']

        if core:
            # Keep the whole name: slicing to 9 characters reported e.g.
            # 'Cortex-M33' or 'Cortex-M55' as 'Cortex-M3' / 'Cortex-M5'.
            arch = core
            if core.startswith('Cortex-M'):
                return arch
        # error_handler raises SystemExit, which 'except Exception' does not catch.
        error_handler(168, 'Unsupported target: {} with architecture: {}'.format(target_name, arch))
    except Exception as e:
        error_handler(167, e)

    return arch
127
128
def test_target(target, args, main_test):
    """Compile one unit test for one target with Mbed CLI.

    The build runs in OUTPUT/<testdir>_UNITY_UNITTEST_<target name>_<compiler>,
    which is created if needed. If that directory has no mbed-os.lib yet, it is
    seeded from BASE_PATH/Mbed/, 'mbed deploy' is run, and the mbed-os
    subdirectories TESTS, UNITTESTS, docker_images, docs, extern and features
    are deleted.

    The 'mbed compile' call passes '-f' together with the error message
    'failed to flash' and die=False, so its failure is reported through the
    return value rather than ending the script. Any Python exception raised
    along the way ends the script via error_handler(166, ...). The original
    working directory is restored before returning.

    :param target: target dict with at least 'name' and 'model'.
    :param args: parsed arguments; 'compiler' and 'testdir' are used.
    :param main_test: unit test directory relative to the starting directory.
    :return: exit status of the 'mbed compile' command.
    """
    status = 3
    compiler = args.compiler
    target_name, target_model = target['name'], target['model']
    framework = 'UNITY_UNITTEST'
    cmsis_nn_root = "../../../../"

    build_dir = OUTPUT + '_'.join([args.testdir, framework, target_name, compiler])
    os.makedirs(build_dir, exist_ok=True)
    original_cwd = os.getcwd()
    os.chdir(build_dir)

    try:
        mbed_dir = BASE_PATH + 'Mbed/'

        if not path.exists("mbed-os.lib"):
            print("Initializing mbed in {}".format(os.getcwd()))
            run_command(f'cp -a {mbed_dir}. .')
            run_command('mbed deploy')
            pruned = ('TESTS', 'UNITTESTS', 'docker_images', 'docs', 'extern', 'features')
            run_command('rm -rf ' + ' '.join('mbed-os/' + sub for sub in pruned))
        arch = detect_architecture(target_model, 'mbed-os/targets/targets.json')

        print("----------------------------------------------------------------")
        print("Running {} on {} target: {} with compiler: {} in directory: {} test: {}\n".format(
            framework, arch, target_name, compiler, os.getcwd(), main_test))

        nn_function_groups = ('ConvolutionFunctions', 'PoolingFunctions', 'NNSupportFunctions',
                              'FullyConnectedFunctions', 'SoftmaxFunctions', 'SVDFunctions',
                              'BasicMathFunctions', 'ActivationFunctions', 'LSTMFunctions')
        source_dirs = ['.', BASE_PATH + 'TestCases/Utils/', cmsis_nn_root + 'Include/']
        source_dirs += [cmsis_nn_root + 'Source/' + group + '/' for group in nn_function_groups]
        source_dirs += [BASE_PATH + main_test, UNITY_SRC]

        command_parts = ['mbed', 'compile', '-v', '-m', target_model, '-t', compiler]
        command_parts += ['--source ' + src for src in source_dirs]
        command_parts += ['--profile', mbed_dir + 'release.json', '-f']

        status = run_command(' '.join(command_parts), 'failed to flash', die=False)

    except Exception as e:
        error_handler(166, e)

    os.chdir(original_cwd)
    return status
189
190
def read_serial_port(ser, inputQueue, stop):
    """Forward lines from a serial port to a queue until *stop* says otherwise.

    Meant to run as a background thread. Each line returned by
    ``ser.readline()`` is decoded as Latin-1, stripped of surrounding
    whitespace and put on *inputQueue*. An empty read (e.g. after a read
    timeout) is forwarded as an empty string.

    :param ser: object with a readline() method returning bytes (a serial port).
    :param inputQueue: queue receiving the decoded lines.
    :param stop: zero-argument callable, checked before every read; reading
        ends once it returns a truthy value.
    """
    while not stop():
        raw = ser.readline()
        inputQueue.put(raw.decode('latin-1').strip())
197
198
def test_target_with_unity(target, args, main_test):
    """Build one unit test for a target and record the Unity results it prints.

    A background thread (read_serial_port) starts reading the target's serial
    port before test_target() builds the test, so early output is not lost.
    After test_target() returns, received lines are printed and inspected for
    at most `timeout` seconds, or until every name in target["tests"] has been
    seen. For each seen test, target[test]["tested"] is set to True, and
    target[test]["pass"] is set to True if its verdict field is 'PASS'.

    NOTE: parse_test() appends the tests of every generated runner to
    target["tests"], so with several unit tests not all names can appear for a
    single binary and the wait then typically lasts the full timeout.

    :param target: target dict (from detect_targets, extended by parse_test).
    :param args: parsed command line arguments, forwarded to test_target.
    :param main_test: unit test directory, forwarded to test_target.
    """
    port = target['port']
    baudrate = 9600
    timeout = 30
    inputQueue = queue.Queue()
    tests = copy.deepcopy(target["tests"])
    # Explicit, thread-safe stop flag for the reader thread.
    stop_reader = threading.Event()

    try:
        ser = serial.Serial(port, baudrate, timeout=timeout)
    except Exception as e:
        error_handler(169, "serial exception: {}".format(e))

    # Clear read buffer
    time.sleep(0.1)  # Workaround in response to: open() returns before port is ready
    ser.reset_input_buffer()

    serial_thread = threading.Thread(target=read_serial_port,
                                     args=(ser, inputQueue, stop_reader.is_set),
                                     daemon=True)
    serial_thread.start()

    test_target(target, args, main_test)

    deadline = time.time() + timeout
    while tests:
        remaining = deadline - time.time()
        if remaining <= 0:
            break
        try:
            # Block for the next line (bounded by the remaining time) instead of
            # polling qsize() in a tight loop.
            str_line = inputQueue.get(timeout=remaining)
        except queue.Empty:
            break
        print(str_line)
        # The test name is expected in the third ':'-separated field and the
        # verdict in the fourth (presumably Unity's FILE:LINE:TEST:RESULT lines).
        fields = str_line.split(':')
        test = fields[2] if len(fields) > 2 else None
        if test in tests:
            tests.remove(test)
            target[test]["tested"] = True
            if ':'.join(fields[2:4]) == test + ':PASS':
                target[test]["pass"] = True

    stop_reader.set()
    # The reader may still be blocked in readline() until the serial timeout expires.
    serial_thread.join()
    ser.close()
243
244
def print_summary(targets):
    """Print a verdict for every test on every target, then an overall summary.

    Each target lists its test names in target["tests"]. Each name maps to a
    dict with boolean "tested" and "pass" flags. Per-test lines are sorted
    alphabetically. In the summary, "failed" counts every expected test that
    did not pass, including tests that never reported a result.

    :param targets: list of target dicts as filled in by parse_test and
        test_target_with_unity.
    :return: 0 if every expected test ran and passed;
             1 if every expected test ran but at least one failed;
             2 if at least one expected test never reported a result.
    """
    ok_tag = colored('[ PASSED ]', 'green')
    fail_tag = colored('[ FAILED ]', 'red')
    err_tag = colored('[ ERROR ]', 'red')

    print("-----------------------------------------------------------------------------------------------------------")

    results = [(target, test) for target in targets for test in target["tests"]]
    expected = len(results)
    tested = sum(1 for target, test in results if target[test]["tested"])
    passed = sum(1 for target, test in results if target[test]["pass"])
    failed = expected - passed

    if tested != expected:
        print("{} Not all tests found".format(err_tag))
        print("{} Expected: {} Actual: {}".format(err_tag, expected, tested))
        return_code = 2
    else:
        return_code = 0 if tested == passed else 1

    def tag_for(target, test):
        outcome = target[test]
        if not outcome["tested"]:
            return err_tag
        return ok_tag if outcome["pass"] else fail_tag

    report = sorted("{} {}: {}".format(tag_for(target, test), target["name"], test)
                    for target, test in results)
    for entry in report:
        print(entry)

    ratio = passed / expected if passed > 0 else 0
    overall = ok_tag if ratio == 1.0 else fail_tag
    target_names = ', '.join([t['name'] for t in targets])
    print("{} Summary: {} tests in total passed on {} target(s) ({})".format(overall, passed, len(targets),
                                                                             target_names))
    print("{} {:.0f}% tests passed, {} tests failed out of {}".format(overall, ratio * 100, failed, expected))

    return return_code
309
310
def test_targets(args):
    """Top-level driver: detect boards, prepare Unity and runners, run all tests.

    With args.download_and_generate set, no boards are detected and the
    function stops after Unity has been fetched and the test runners have
    been generated.

    :param args: parsed command line arguments (see parse_args).
    :return: 3 if no targets were detected, 4 if no tests were found,
             0 after a download-and-generate-only run, otherwise the code
             returned by print_summary (0, 1 or 2).
    """
    targets = []
    main_tests = []
    generate_only = args.download_and_generate

    if not generate_only:
        detect_targets(targets)
        if not targets:
            print("No targets detected!")
            return 3

    download_unity()

    if not parse_tests(targets, main_tests, args.specific_test):
        print("No tests found?!")
        return 4

    if generate_only:
        return 0

    for target in targets:
        for main_test in main_tests:
            test_target_with_unity(target, args, main_test)

    return print_summary(targets)
343
344
def download_unity(force=False):
    """Make sure the Unity test framework sources are present in UNITY_PATH.

    Nothing is done if UNITY_PATH already contains src/unity.c and
    src/unity.h, unless *force* is True. Otherwise UNITY_PATH is recreated
    empty, the Unity v2.5.0 tarball is downloaded with curl into a fresh
    temporary directory, and it is unpacked into UNITY_PATH with its top-level
    directory stripped. The temporary directory is always removed afterwards.

    The script exits via error_handler (code 171) if curl fails or leaves no
    file behind, and via run_command if tar fails.

    :param force: download again even if Unity already seems to be present.
    """
    import tempfile

    unity_dir = UNITY_PATH
    unity_src = unity_dir + "src/"

    # Check if already downloaded
    if not force and path.isdir(unity_dir) and path.isfile(unity_src + "unity.c") and \
            path.isfile(unity_src + "unity.h"):
        return

    if path.isdir(unity_dir):
        shutil.rmtree(unity_dir)
    os.mkdir(unity_dir)

    # mkdtemp() picks a name and creates a private directory in one atomic
    # step, so no other process can claim the name in between.
    download_dir = tempfile.mkdtemp()
    try:
        downloaded_file = path.join(download_dir, "unity_tarball.tar.gz")
        # --fail: exit non-zero on HTTP errors instead of saving the error page
        # as the "tarball".
        with subprocess.Popen(
                ['curl', '--fail', '-LJ', 'https://api.github.com/repos/ThrowTheSwitch/Unity/tarball/v2.5.0',
                 '--output', downloaded_file],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
                universal_newlines=True) as process:
            # curl writes its progress meter to stderr; echo it.
            for line in process.stderr:
                print(line.strip())
        print()
        # Leaving the 'with' block waited for curl, so returncode is set.
        if process.returncode != 0 or not path.isfile(downloaded_file):
            error_handler(171, 'downloading Unity failed (curl exit code {})'.format(process.returncode))
        run_command("tar xzf " + downloaded_file + " -C " + unity_dir + " --strip-components=1")
    finally:
        # Cleanup, also when error_handler/run_command end the script.
        shutil.rmtree(download_dir, ignore_errors=True)
392
393
def parse_generated_test_runner(test_runner):
    """Rewrite a generated test runner without its redundant prototypes.

    Removes every line that contains '(void);' and also mentions setUp,
    tearDown, resetTest or verifyTest, which avoids 'redundant redeclaration'
    warnings. All other lines are written back unchanged.

    :param test_runner: path of the generated runner source, rewritten in place.
    """
    redundant_names = ('setUp', 'tearDown', 'resetTest', 'verifyTest')
    void_prototype = re.compile(r"\(void\);")

    with open(test_runner, "r") as source:
        original_lines = source.readlines()

    kept_lines = [
        line for line in original_lines
        if not (void_prototype.search(line) and any(name in line for name in redundant_names))
    ]

    with open(test_runner, "w") as target:
        target.writelines(kept_lines)
413
414
415def parse_tests(targets, main_tests, specific_test=None):
416    """
417    Generate test runners, extract and return path to unit test(s).
418    Also parse generated test runners to avoid warning: redundant redeclaration.
419    Return True if successful.
420    """
421    test_found = False
422    directory = 'TestCases'
423
424    if specific_test and '/' in specific_test:
425        specific_test = specific_test.strip(directory).replace('/', '')
426
427    for dir in next(os.walk(directory))[1]:
428        if re.search(r'test_arm', dir):
429            if specific_test and dir != specific_test:
430                continue
431            test_found = True
432            testpath = directory + '/' + dir + '/Unity/'
433            ut_test_file = None
434            for content in os.listdir(testpath):
435                if re.search(r'unity_test_arm', content):
436                    ut_test_file = content
437            if ut_test_file is None:
438                print("Warning: invalid path: ", testpath)
439                continue
440            main_tests.append(testpath)
441            ut_test_file_runner = path.splitext(ut_test_file)[0] + '_runner' + path.splitext(ut_test_file)[1]
442            test_code = testpath + ut_test_file
443            test_runner_path = testpath + 'TestRunner/'
444            if not os.path.exists(test_runner_path):
445                os.mkdir(test_runner_path)
446            test_runner = test_runner_path + ut_test_file_runner
447            for old_files in glob.glob(test_runner_path + '/*'):
448                if not old_files.endswith('readme.txt'):
449                    os.remove(old_files)
450
451            # Generate test runners
452            run_command('ruby ' + UNITY_PATH + 'auto/generate_test_runner.rb ' + test_code + ' ' + test_runner)
453            test_found = parse_test(test_runner, targets)
454            if not test_found:
455                return False
456
457            parse_generated_test_runner(test_runner)
458
459    if not test_found:
460        return False
461    return True
462
463
def parse_test(test_runner, targets):
    """Register the tests of a generated runner on every target.

    A line counts as a test invocation when it contains '  run_test(' (two
    leading spaces) and, once stripped, splits into exactly three
    comma-separated fields. The test name is the text between '(' and the
    first comma. For every target, the name is appended to target['tests']
    (created if missing), and target[name] is reset to
    {"pass": False, "tested": False}.

    Exits via error_handler(170, ...) if the runner file cannot be opened.

    :param test_runner: path of the generated Unity test runner.
    :param targets: list of target dicts, modified in place.
    :return: True if at least one test invocation was found, else False.
    """
    run_test_call = re.compile(r"  run_test\(")
    found_any = False

    try:
        runner_file = open(test_runner, "r")
    except IOError as e:
        error_handler(170, e)
        return found_any

    with runner_file:
        for line in runner_file:
            fields = line.strip().split(',')
            if not (run_test_call.search(line) and len(fields) == 3):
                continue
            test_name = fields[0].split('(')[1]
            found_any = True
            for target in targets:
                target.setdefault('tests', []).append(test_name)
                target[test_name] = {"pass": False, "tested": False}
    return found_any
488
489
if __name__ == '__main__':
    # The process exit status is whatever test_targets returns.
    sys.exit(test_targets(parse_args()))
493