1#!/usr/bin/env python3 2# 3# SPDX-FileCopyrightText: Copyright 2010-2020, 2022 Arm Limited and/or its affiliates <open-source-office@arm.com> 4# 5# SPDX-License-Identifier: Apache-2.0 6# 7# Licensed under the Apache License, Version 2.0 (the License); you may 8# not use this file except in compliance with the License. 9# You may obtain a copy of the License at 10# 11# www.apache.org/licenses/LICENSE-2.0 12# 13# Unless required by applicable law or agreed to in writing, software 14# distributed under the License is distributed on an AS IS BASIS, WITHOUT 15# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16# See the License for the specific language governing permissions and 17# limitations under the License. 18# 19import os 20import re 21import sys 22import json 23import copy 24import glob 25import time 26import queue 27import shutil 28import serial 29import argparse 30import threading 31import subprocess 32 33from os import path 34from termcolor import colored 35 36OUTPUT = "Output/" 37BASE_PATH = "../../" 38UNITY_PATH = "Unity/" 39UNITY_BASE = BASE_PATH + UNITY_PATH 40UNITY_SRC = UNITY_BASE + "src/" 41 42 43def parse_args(): 44 parser = argparse.ArgumentParser(description="Run CMSIS-NN unit tests.", 45 epilog="Runs on all connected HW supported by Mbed.") 46 parser.add_argument('--testdir', type=str, default='TESTRUN', help="prefix of output dir name") 47 parser.add_argument('-s', '--specific-test', type=str, default=None, help="Run a specific test, e.g." 48 " -s TestCases/test_arm_avgpool_s8 (also this form will work: -s test_arm_avgpool_s8)." 49 " So basically the different options can be listed with:" 50 " ls -d TestCases/test_* -1") 51 parser.add_argument('-c', '--compiler', type=str, default='GCC_ARM', choices=['GCC_ARM', 'ARMC6']) 52 parser.add_argument('--download-and-generate-test-runners', dest='download_and_generate', action='store_true', 53 help="Just download Unity and generate test runners if needed") 54 55 required_arguments = parser.add_argument_group('required named arguments') 56 required_arguments.add_argument('-p', '--path-to-cmsis', help='Path to CMSIS project for CMSIS/Core.', 57 required=True) 58 59 args = parser.parse_args() 60 return args 61 62 63def error_handler(code, text=None): 64 print("Error: {}".format(text)) 65 sys.exit(code) 66 67 68def detect_targets(targets): 69 process = subprocess.Popen(['mbedls'], 70 stdout=subprocess.PIPE, 71 stderr=subprocess.PIPE, 72 universal_newlines=True) 73 print(process.stdout.readline().strip()) 74 while True: 75 line = process.stdout.readline() 76 print(line.strip()) 77 if not line: 78 break 79 if re.search(r"^\| ", line): 80 words = (line.split('| ')) 81 target = {"model": words[1].strip(), 82 "name": words[2].strip()[:-1].replace('[', '_'), 83 "port": words[4].strip(), 84 "tid": words[5].strip()} # Target id can be used to filter out targets 85 targets.append(target) 86 return_code = process.poll() 87 if return_code != 0: 88 error_handler(return_code, 'RETURN CODE {}'.format(process.stderr.read())) 89 90 91def run_command(command, error_msg=None, die=True): 92 # TODO handle error: 93 # cp: error writing '/media/mannil01/NODE_F411RE/TESTRUN_NUCLEO_F411RE_GCC_ARM.bin': No space left on device 94 # https://os.mbed.com/questions/59636/STM-Nucleo-No-space-left-on-device-when-/ 95 96 # print(command) 97 command_list = command.split(' ') 98 process = subprocess.run(command_list) 99 if die and process.returncode != 0: 100 error_handler(process.returncode, error_msg) 101 return process.returncode 102 103 104def detect_architecture(target_name, target_json): 105 arch = None 106 107 try: 108 with open(target_json, "r") as read_file: 109 data = json.load(read_file) 110 111 if 'core' in data[target_name]: 112 core = data[target_name]['core'] 113 elif 'inherits' in data[target_name]: 114 target_inherits = data[target_name]['inherits'][0] 115 core = data[target_inherits]['core'] 116 else: 117 raise Exception("Cannot detect architecture") 118 119 if core: 120 arch = core[:9] 121 if core[:8] == 'Cortex-M': 122 return arch 123 error_handler(168, 'Unsupported target: {} with architecture: {}'.format( 124 target_name, arch)) 125 except Exception as e: 126 error_handler(167, e) 127 128 return arch 129 130 131def test_target(target, args, main_test): 132 result = 3 133 compiler = args.compiler 134 cmsis_path = args.path_to_cmsis 135 cmsis_nn_path = "../../../../" 136 target_name = target['name'] 137 target_model = target['model'] 138 unittestframework = 'UNITY_UNITTEST' 139 140 dir_name = OUTPUT + args.testdir + '_' + unittestframework + '_' + target_name + '_' + compiler 141 142 os.makedirs(dir_name, exist_ok=True) 143 start_dir = os.getcwd() 144 os.chdir(dir_name) 145 146 cmsis_core_include_path = cmsis_path + '/CMSIS/Core/Include/' 147 relative_cmsis_core_include_path = os.path.relpath(cmsis_core_include_path, '.') 148 149 try: 150 target_json = 'mbed-os/targets/targets.json' 151 mbed_path = BASE_PATH + 'Mbed/' 152 153 if not path.exists("mbed-os.lib"): 154 print("Initializing mbed in {}".format(os.getcwd())) 155 shutil.copyfile(mbed_path + 'mbed-os.lib', 'mbed-os.lib') 156 shutil.copyfile(mbed_path + 'mbed_app.json', 'mbed_app.json') 157 run_command('mbed config root .') 158 run_command('mbed deploy') 159 arch = detect_architecture(target_model, target_json) 160 161 print("----------------------------------------------------------------") 162 print("Running {} on {} target: {} with compiler: {} in directory: {} test: {}\n".format( 163 unittestframework, arch, target_name, compiler, os.getcwd(), main_test)) 164 165 die = False 166 flash_error_msg = 'failed to flash' 167 mbed_command = "compile" 168 test = '' 169 additional_options = ' --source ' + BASE_PATH + main_test + \ 170 ' --source ' + UNITY_SRC + \ 171 ' --profile ' + mbed_path + 'release.json' + \ 172 ' -f' 173 174 result = run_command("mbed {} -v -m ".format(mbed_command) + target_model + ' -t ' + compiler + 175 test + 176 ' --source .' 177 ' --source ' + BASE_PATH + 'TestCases/Utils/' 178 ' --source ' + cmsis_nn_path + 'Include/' 179 ' --source ' + relative_cmsis_core_include_path + 180 ' --source ' + cmsis_nn_path + 'Source/ConvolutionFunctions/' 181 ' --source ' + cmsis_nn_path + 'Source/PoolingFunctions/' 182 ' --source ' + cmsis_nn_path + 'Source/NNSupportFunctions/' 183 ' --source ' + cmsis_nn_path + 'Source/FullyConnectedFunctions/' 184 ' --source ' + cmsis_nn_path + 'Source/SoftmaxFunctions/' 185 ' --source ' + cmsis_nn_path + 'Source/SVDFunctions/' 186 ' --source ' + cmsis_nn_path + 'Source/BasicMathFunctions/' 187 + additional_options, 188 flash_error_msg, die=die) 189 190 except Exception as e: 191 error_handler(166, e) 192 193 os.chdir(start_dir) 194 return result 195 196 197def read_serial_port(ser, inputQueue, stop): 198 while True: 199 if stop(): 200 break 201 line = ser.readline() 202 inputQueue.put(line.decode('latin-1').strip()) 203 204 205def test_target_with_unity(target, args, main_test): 206 port = target['port'] 207 stop_thread = False 208 baudrate = 9600 209 timeout = 30 210 inputQueue = queue.Queue() 211 tests = copy.deepcopy(target["tests"]) 212 213 try: 214 ser = serial.Serial(port, baudrate, timeout=timeout) 215 except Exception as e: 216 error_handler(169, "serial exception: {}".format(e)) 217 218 # Clear read buffer 219 time.sleep(0.1) # Workaround in response to: open() returns before port is ready 220 ser.reset_input_buffer() 221 222 serial_thread = threading.Thread(target=read_serial_port, args=(ser, inputQueue, lambda: stop_thread), daemon=True) 223 serial_thread.start() 224 225 test_target(target, args, main_test) 226 227 start_time = time.time() 228 while time.time() < start_time + timeout: 229 if inputQueue.qsize() > 0: 230 str_line = inputQueue.get() 231 print(str_line) 232 test = None 233 try: 234 test = str_line.split(':')[2] 235 test_result = ':'.join(str_line.split(':')[2:4]) 236 except IndexError: 237 pass 238 if test in tests: 239 tests.remove(test) 240 target[test]["tested"] = True 241 if test_result == test + ':PASS': 242 target[test]["pass"] = True 243 if len(tests) == 0: 244 break 245 246 stop_thread = True 247 serial_thread.join() 248 ser.close() 249 250 251def print_summary(targets): 252 """ 253 Return 0 if all test passed 254 Return 1 if all test completed but one or more failed 255 Return 2 if one or more tests did not complete or was not detected 256 """ 257 passed = 0 258 failed = 0 259 tested = 0 260 expected = 0 261 return_code = 3 262 verdict_pass = colored('[ PASSED ]', 'green') 263 verdict_fail = colored('[ FAILED ]', 'red') 264 verdict_error = colored('[ ERROR ]', 'red') 265 266 print("-----------------------------------------------------------------------------------------------------------") 267 268 # Find all passed and failed 269 for target in targets: 270 for test in target["tests"]: 271 expected += 1 272 if target[test]["tested"]: 273 tested += 1 274 if target[test]["pass"]: 275 passed += 1 276 else: 277 failed += 1 278 279 if tested != expected: 280 print("{} Not all tests found".format(verdict_error)) 281 print("{} Expected: {} Actual: {}".format(verdict_error, expected, tested)) 282 return_code = 2 283 elif tested == passed: 284 return_code = 0 285 else: 286 return_code = 1 287 288 # print all test cases 289 sorted_tc = [] 290 for target in targets: 291 for test in target["tests"]: 292 if not target[test]["tested"]: 293 tc_verdict = verdict_error 294 elif target[test]["pass"]: 295 tc_verdict = verdict_pass 296 else: 297 tc_verdict = verdict_fail 298 sorted_tc.append("{} {}: {}".format(tc_verdict, target["name"], test)) 299 sorted_tc.sort() 300 for tc in sorted_tc: 301 print(tc) 302 303 total = 0 304 if (passed > 0): 305 total = passed / expected 306 if (total == 1.0): 307 verdict = verdict_pass 308 else: 309 verdict = verdict_fail 310 print("{} Summary: {} tests in total passed on {} target(s) ({})". 311 format(verdict, passed, len(targets), ', '.join([t['name'] for t in targets]))) 312 print("{} {:.0f}% tests passed, {} tests failed out of {}".format(verdict, total*100, failed, expected)) 313 314 return return_code 315 316 317def test_targets(args): 318 """ 319 Return 0 if successful 320 Return 3 if no targets are detected 321 Return 4 if no tests are found 322 """ 323 result = 0 324 targets = [] 325 main_tests = [] 326 327 if not args.download_and_generate: 328 detect_targets(targets) 329 if len(targets) == 0: 330 print("No targets detected!") 331 return 3 332 333 download_unity() 334 335 if not parse_tests(targets, main_tests, args.specific_test): 336 print("No tests found?!") 337 return 4 338 339 if args.download_and_generate: 340 return result 341 342 for target in targets: 343 for tst in main_tests: 344 test_target_with_unity(target, args, tst) 345 346 result = print_summary(targets) 347 348 return result 349 350 351def download_unity(force=False): 352 unity_dir = UNITY_PATH 353 unity_src = unity_dir+"src/" 354 process = subprocess.run(['mktemp'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True) 355 download_dir = process.stdout.strip() 356 run_command("rm -f {}".format(download_dir)) 357 download_dir += '/' 358 359 # Check if already downloaded 360 if not force and path.isdir(unity_dir) and path.isfile(unity_src+"unity.c") and path.isfile(unity_src+"unity.h"): 361 return 362 363 if path.isdir(download_dir): 364 shutil.rmtree(download_dir) 365 if path.isdir(unity_dir): 366 shutil.rmtree(unity_dir) 367 os.mkdir(unity_dir) 368 os.makedirs(download_dir, exist_ok=False) 369 current_dir = os.getcwd() 370 os.chdir(download_dir) 371 process = subprocess.Popen('curl -LJ https://api.github.com/repos/ThrowTheSwitch/Unity/tarball/v2.5.0 --output unity_tarball.tar.gz'.split(), 372 stdout=subprocess.PIPE, 373 stderr=subprocess.PIPE, 374 universal_newlines=True) 375 for line in process.stderr: 376 print(line.strip()) 377 print() 378 for line in process.stdout: 379 pass 380 if not line: 381 error_handler(171) 382 downloaded_file = download_dir + "unity_tarball.tar.gz" 383 os.chdir(current_dir) 384 try: 385 filename_base = downloaded_file.split('-')[0] 386 except IndexError as e: 387 error_handler(174, e) 388 if not filename_base: 389 error_handler(175) 390 run_command("tar xzf "+downloaded_file+" -C "+unity_dir+" --strip-components=1") 391 os.chdir(current_dir) 392 393 # Cleanup 394 shutil.rmtree(download_dir) 395 396 397def parse_generated_test_runner(test_runner): 398 parsed_functions = ['setUp', 'tearDown', 'resetTest', 'verifyTest'] 399 400 def is_func_to_parse(func): 401 for f in parsed_functions: 402 if f in func: 403 return True 404 return False 405 406 with open(test_runner, "r") as f: 407 lines = f.readlines() 408 with open(test_runner, "w") as f: 409 for line in lines: 410 sline = line.strip('\n') 411 if not re.search(r"\(void\);", sline): 412 f.write(line) 413 else: 414 if not is_func_to_parse(sline): 415 f.write(line) 416 417 418def parse_tests(targets, main_tests, specific_test=None): 419 """ 420 Generate test runners, extract and return path to unit test(s). 421 Also parse generated test runners to avoid warning: redundant redeclaration. 422 Return True if successful. 423 """ 424 test_found = False 425 directory = 'TestCases' 426 427 if specific_test and '/' in specific_test: 428 specific_test = specific_test.strip(directory).replace('/', '') 429 430 for dir in next(os.walk(directory))[1]: 431 if re.search(r'test_arm', dir): 432 if specific_test and dir != specific_test: 433 continue 434 test_found = True 435 testpath = directory + '/' + dir + '/Unity/' 436 ut_test_file = None 437 for content in os.listdir(testpath): 438 if re.search(r'unity_test_arm', content): 439 ut_test_file = content 440 if ut_test_file is None: 441 print("Warning: invalid path: ", testpath) 442 continue 443 main_tests.append(testpath) 444 ut_test_file_runner = path.splitext(ut_test_file)[0] + '_runner' + path.splitext(ut_test_file)[1] 445 test_code = testpath + ut_test_file 446 test_runner_path = testpath + 'TestRunner/' 447 if not os.path.exists(test_runner_path): 448 os.mkdir(test_runner_path) 449 test_runner = test_runner_path + ut_test_file_runner 450 for old_files in glob.glob(test_runner_path + '/*'): 451 if not old_files.endswith('readme.txt'): 452 os.remove(old_files) 453 454 # Generate test runners 455 run_command('ruby '+UNITY_PATH+'auto/generate_test_runner.rb ' + test_code + ' ' + test_runner) 456 test_found = parse_test(test_runner, targets) 457 if not test_found: 458 return False 459 460 parse_generated_test_runner(test_runner) 461 462 if not test_found: 463 return False 464 return True 465 466 467def parse_test(test_runner, targets): 468 tests_found = False 469 470 # Get list of tests 471 try: 472 read_file = open(test_runner, "r") 473 except IOError as e: 474 error_handler(170, e) 475 else: 476 with read_file: 477 for line in read_file: 478 if not line: 479 break 480 if re.search(r" run_test\(", line) and len(line.strip().split(',')) == 3: 481 function = line.strip().split(',')[0].split('(')[1] 482 tests_found = True 483 for target in targets: 484 if 'tests' not in target.keys(): 485 target['tests'] = [] 486 target["tests"].append(function) 487 target[function] = {} 488 target[function]["pass"] = False 489 target[function]["tested"] = False 490 return tests_found 491 492 493if __name__ == '__main__': 494 args = parse_args() 495 sys.exit(test_targets(args)) 496