#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import sys
import json
import copy
import glob
import time
import queue
import shutil
import serial
import argparse
import threading
import subprocess

from os import path
from termcolor import colored

OUTPUT = "Output/"
BASE_PATH = "../../"
UNITY_PATH = "Unity/"
UNITY_BASE = BASE_PATH + UNITY_PATH
UNITY_SRC = UNITY_BASE + "src/"


def parse_args():
    parser = argparse.ArgumentParser(description="Run CMSIS-NN unit tests.",
                                     epilog="Runs on all connected HW supported by Mbed.")
    parser.add_argument('--testdir', type=str, default='TESTRUN', help="prefix of output dir name")
    parser.add_argument('-s',
                        '--specific-test',
                        type=str,
                        default=None,
                        help="Run a specific test, e.g."
                        " -s TestCases/test_arm_avgpool_s8 (this form also works: -s test_arm_avgpool_s8)."
                        " The available options can be listed with:"
                        " ls -d TestCases/test_* -1")
    parser.add_argument('-c', '--compiler', type=str, default='GCC_ARM', choices=['GCC_ARM', 'ARMC6'])
    parser.add_argument('--download-and-generate-test-runners',
                        dest='download_and_generate',
                        action='store_true',
                        help="Just download Unity and generate test runners if needed")

    required_arguments = parser.add_argument_group('required named arguments')

    args = parser.parse_args()
    return args
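
# Illustrative invocations (the script file name below is a placeholder; the options are
# the ones defined in parse_args() above):
#   python3 unittest_targets.py                                    # run all tests on all detected boards
#   python3 unittest_targets.py -s test_arm_avgpool_s8 -c GCC_ARM  # run a single test with GCC
#   python3 unittest_targets.py --download-and-generate-test-runners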


def error_handler(code, text=None):
    print("Error: {}".format(text))
    sys.exit(code)


def detect_targets(targets):
    process = subprocess.Popen(['mbedls'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    print(process.stdout.readline().strip())
    while True:
        line = process.stdout.readline()
        print(line.strip())
        if not line:
            break
        if re.search(r"^\| ", line):
            words = line.split('| ')
            target = {
                "model": words[1].strip(),
                "name": words[2].strip()[:-1].replace('[', '_'),
                "port": words[4].strip(),
                "tid": words[5].strip()
            }  # Target id can be used to filter out targets
            targets.append(target)
    return_code = process.wait()  # wait() rather than poll() so a still-running mbedls is not misread
    if return_code != 0:
        error_handler(return_code, 'RETURN CODE {}'.format(process.stderr.read()))


def run_command(command, error_msg=None, die=True):
    # TODO handle error:
    # cp: error writing '/media/mannil01/NODE_F411RE/TESTRUN_NUCLEO_F411RE_GCC_ARM.bin': No space left on device
    # https://os.mbed.com/questions/59636/STM-Nucleo-No-space-left-on-device-when-/

    # print(command)
    command_list = command.split(' ')
    process = subprocess.run(command_list)
    if die and process.returncode != 0:
        error_handler(process.returncode, error_msg)
    return process.returncode


def detect_architecture(target_name, target_json):
    arch = None
    try:
        with open(target_json, "r") as read_file:
            data = json.load(read_file)

        while 'core' not in data[target_name]:
            print(f"{target_name} inherits from {data[target_name]['inherits'][0]}")
            target_name = data[target_name]['inherits'][0]
        core = data[target_name]['core']

        if core:
            arch = core[:9]
            if core[:8] == 'Cortex-M':
                return arch
        error_handler(168, 'Unsupported target: {} with architecture: {}'.format(target_name, arch))
    except Exception as e:
        error_handler(167, e)

    return arch


def test_target(target, args, main_test):
    result = 3
    compiler = args.compiler
    cmsis_nn_path = "../../../../"
    target_name = target['name']
    target_model = target['model']
    unittestframework = 'UNITY_UNITTEST'

    dir_name = OUTPUT + args.testdir + '_' + unittestframework + '_' + target_name + '_' + compiler

    os.makedirs(dir_name, exist_ok=True)
    start_dir = os.getcwd()
    os.chdir(dir_name)

    try:
        target_json = 'mbed-os/targets/targets.json'
        mbed_path = BASE_PATH + 'Mbed/'

        if not path.exists("mbed-os.lib"):
            print("Initializing mbed in {}".format(os.getcwd()))
            run_command(f'cp -a {mbed_path}. .')
            run_command('mbed deploy')
            run_command('rm -rf mbed-os/TESTS' + ' mbed-os/UNITTESTS' + ' mbed-os/docker_images' + ' mbed-os/docs' +
                        ' mbed-os/extern' + ' mbed-os/features')
        arch = detect_architecture(target_model, target_json)

        print("----------------------------------------------------------------")
        print("Running {} on {} target: {} with compiler: {} in directory: {} test: {}\n".format(
            unittestframework, arch, target_name, compiler, os.getcwd(), main_test))

        die = False
        flash_error_msg = 'failed to flash'
        mbed_command = "compile"
        test = ''
        additional_options = ' --source ' + BASE_PATH + main_test + \
                             ' --source ' + UNITY_SRC + \
                             ' --profile ' + mbed_path + 'release.json' + \
                             ' -f'
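
        # For reference, the command assembled below has roughly this shape (values depend
        # on the detected board and selected compiler):
        #   mbed compile -v -m <target_model> -t <GCC_ARM|ARMC6> --source . \
        #       --source <CMSIS-NN sources...> --source <unit test sources...> \
        #       --profile ../../Mbed/release.json -f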
        result = run_command("mbed {} -v -m ".format(mbed_command) + target_model + ' -t ' + compiler + test +
                             ' --source .'
                             ' --source ' + BASE_PATH + 'TestCases/Utils/'
                             ' --source ' + cmsis_nn_path + 'Include/'
                             ' --source ' + cmsis_nn_path + 'Source/ConvolutionFunctions/'
                             ' --source ' + cmsis_nn_path + 'Source/PoolingFunctions/'
                             ' --source ' + cmsis_nn_path + 'Source/NNSupportFunctions/'
                             ' --source ' + cmsis_nn_path + 'Source/FullyConnectedFunctions/'
                             ' --source ' + cmsis_nn_path + 'Source/SoftmaxFunctions/'
                             ' --source ' + cmsis_nn_path + 'Source/SVDFunctions/'
                             ' --source ' + cmsis_nn_path + 'Source/BasicMathFunctions/'
                             ' --source ' + cmsis_nn_path + 'Source/ActivationFunctions/'
                             ' --source ' + cmsis_nn_path + 'Source/LSTMFunctions/' + additional_options,
                             flash_error_msg,
                             die=die)

    except Exception as e:
        error_handler(166, e)

    os.chdir(start_dir)
    return result


def read_serial_port(ser, inputQueue, stop):
    while True:
        if stop():
            break
        line = ser.readline()
        inputQueue.put(line.decode('latin-1').strip())


def test_target_with_unity(target, args, main_test):
    port = target['port']
    stop_thread = False
    baudrate = 9600
    timeout = 30
    inputQueue = queue.Queue()
    tests = copy.deepcopy(target["tests"])

    try:
        ser = serial.Serial(port, baudrate, timeout=timeout)
    except Exception as e:
        error_handler(169, "serial exception: {}".format(e))

    # Clear read buffer
    time.sleep(0.1)  # Workaround in response to: open() returns before port is ready
    ser.reset_input_buffer()

    serial_thread = threading.Thread(target=read_serial_port, args=(ser, inputQueue, lambda: stop_thread), daemon=True)
    serial_thread.start()

    test_target(target, args, main_test)
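
    # The read loop below assumes Unity's default result format on the serial line, i.e.
    # '<file>:<line>:<test_name>:PASS' or '<file>:<line>:<test_name>:FAIL[: message]',
    # which is why only the third and fourth colon-separated fields are inspected.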
    start_time = time.time()
    while time.time() < start_time + timeout:
        if inputQueue.qsize() > 0:
            str_line = inputQueue.get()
            print(str_line)
            test = None
            try:
                test = str_line.split(':')[2]
                test_result = ':'.join(str_line.split(':')[2:4])
            except IndexError:
                pass
            if test in tests:
                tests.remove(test)
                target[test]["tested"] = True
                if test_result == test + ':PASS':
                    target[test]["pass"] = True
            if len(tests) == 0:
                break

    stop_thread = True
    serial_thread.join()
    ser.close()


def print_summary(targets):
    """
    Return 0 if all tests passed
    Return 1 if all tests completed but one or more failed
    Return 2 if one or more tests did not complete or were not detected
    """
    passed = 0
    failed = 0
    tested = 0
    expected = 0
    return_code = 3
    verdict_pass = colored('[ PASSED ]', 'green')
    verdict_fail = colored('[ FAILED ]', 'red')
    verdict_error = colored('[ ERROR ]', 'red')

    print("-----------------------------------------------------------------------------------------------------------")

    # Find all passed and failed tests
    for target in targets:
        for test in target["tests"]:
            expected += 1
            if target[test]["tested"]:
                tested += 1
                if target[test]["pass"]:
                    passed += 1
                else:
                    failed += 1

    if tested != expected:
        print("{} Not all tests found".format(verdict_error))
        print("{} Expected: {} Actual: {}".format(verdict_error, expected, tested))
        return_code = 2
    elif tested == passed:
        return_code = 0
    else:
        return_code = 1

    # Print all test cases
    sorted_tc = []
    for target in targets:
        for test in target["tests"]:
            if not target[test]["tested"]:
                tc_verdict = verdict_error
            elif target[test]["pass"]:
                tc_verdict = verdict_pass
            else:
                tc_verdict = verdict_fail
            sorted_tc.append("{} {}: {}".format(tc_verdict, target["name"], test))
    sorted_tc.sort()
    for tc in sorted_tc:
        print(tc)

    total = 0
    if passed > 0:
        total = passed / expected
    if total == 1.0:
        verdict = verdict_pass
    else:
        verdict = verdict_fail
    print("{} Summary: {} tests in total passed on {} target(s) ({})".format(verdict, passed, len(targets),
                                                                             ', '.join([t['name'] for t in targets])))
    print("{} {:.0f}% tests passed, {} tests failed out of {}".format(verdict, total * 100, failed, expected))

    return return_code


def test_targets(args):
    """
    Return 0 if successful
    Return 3 if no targets are detected
    Return 4 if no tests are found
    """
    result = 0
    targets = []
    main_tests = []

    if not args.download_and_generate:
        detect_targets(targets)
        if len(targets) == 0:
            print("No targets detected!")
            return 3

    download_unity()

    if not parse_tests(targets, main_tests, args.specific_test):
        print("No tests found?!")
        return 4

    if args.download_and_generate:
        return result

    for target in targets:
        for tst in main_tests:
            test_target_with_unity(target, args, tst)

    result = print_summary(targets)

    return result


def download_unity(force=False):
    unity_dir = UNITY_PATH
    unity_src = unity_dir + "src/"
    process = subprocess.run(['mktemp'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    download_dir = process.stdout.strip()
    run_command("rm -f {}".format(download_dir))
    download_dir += '/'

    # Check if already downloaded
    if not force and path.isdir(unity_dir) and path.isfile(unity_src + "unity.c") and path.isfile(unity_src +
                                                                                                  "unity.h"):
        return

    if path.isdir(download_dir):
        shutil.rmtree(download_dir)
    if path.isdir(unity_dir):
        shutil.rmtree(unity_dir)
    os.mkdir(unity_dir)
    os.makedirs(download_dir, exist_ok=False)
    current_dir = os.getcwd()
    os.chdir(download_dir)
    process = subprocess.Popen(
        'curl -LJ https://api.github.com/repos/ThrowTheSwitch/Unity/tarball/v2.5.0 --output unity_tarball.tar.gz'.split(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        universal_newlines=True)
    for line in process.stderr:
        print(line.strip())
    print()
    for line in process.stdout:
        pass
    if process.wait() != 0:  # Check curl's exit code rather than the last printed line
        error_handler(171)
    downloaded_file = download_dir + "unity_tarball.tar.gz"
    os.chdir(current_dir)
    try:
        filename_base = downloaded_file.split('-')[0]
    except IndexError as e:
        error_handler(174, e)
    if not filename_base:
        error_handler(175)
    run_command("tar xzf " + downloaded_file + " -C " + unity_dir + " --strip-components=1")
    os.chdir(current_dir)

    # Cleanup
    shutil.rmtree(download_dir)
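
# Note: the generated runner typically contains forward declarations such as
# 'void setUp(void);' and 'void tearDown(void);'. parse_generated_test_runner() below drops
# those lines again, since the unit test code already declares them and the compiler would
# otherwise warn about redundant redeclarations.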
def parse_generated_test_runner(test_runner):
    parsed_functions = ['setUp', 'tearDown', 'resetTest', 'verifyTest']

    def is_func_to_parse(func):
        for f in parsed_functions:
            if f in func:
                return True
        return False

    with open(test_runner, "r") as f:
        lines = f.readlines()
    with open(test_runner, "w") as f:
        for line in lines:
            sline = line.strip('\n')
            if not re.search(r"\(void\);", sline):
                f.write(line)
            elif not is_func_to_parse(sline):
                f.write(line)


def parse_tests(targets, main_tests, specific_test=None):
    """
    Generate test runners and collect the path(s) to the unit test(s) in main_tests.
    Also post-process the generated test runners to avoid 'redundant redeclaration' warnings.
    Return True if successful.
    """
    test_found = False
    directory = 'TestCases'

    if specific_test and '/' in specific_test:
        # Accept both 'TestCases/<test>' and '<test>' forms
        specific_test = path.basename(specific_test.rstrip('/'))

    for dir in next(os.walk(directory))[1]:
        if re.search(r'test_arm', dir):
            if specific_test and dir != specific_test:
                continue
            test_found = True
            testpath = directory + '/' + dir + '/Unity/'
            ut_test_file = None
            for content in os.listdir(testpath):
                if re.search(r'unity_test_arm', content):
                    ut_test_file = content
            if ut_test_file is None:
                print("Warning: invalid path: ", testpath)
                continue
            main_tests.append(testpath)
            ut_test_file_runner = path.splitext(ut_test_file)[0] + '_runner' + path.splitext(ut_test_file)[1]
            test_code = testpath + ut_test_file
            test_runner_path = testpath + 'TestRunner/'
            if not os.path.exists(test_runner_path):
                os.mkdir(test_runner_path)
            test_runner = test_runner_path + ut_test_file_runner
            for old_files in glob.glob(test_runner_path + '/*'):
                if not old_files.endswith('readme.txt'):
                    os.remove(old_files)

            # Generate test runners
            run_command('ruby ' + UNITY_PATH + 'auto/generate_test_runner.rb ' + test_code + ' ' + test_runner)
            test_found = parse_test(test_runner, targets)
            if not test_found:
                return False

            parse_generated_test_runner(test_runner)

    return test_found


def parse_test(test_runner, targets):
    tests_found = False

    # Get list of tests
    try:
        read_file = open(test_runner, "r")
    except IOError as e:
        error_handler(170, e)
    else:
        with read_file:
            for line in read_file:
                if not line:
                    break
                if re.search(r" run_test\(", line) and len(line.strip().split(',')) == 3:
                    function = line.strip().split(',')[0].split('(')[1]
                    tests_found = True
                    for target in targets:
                        if 'tests' not in target:
                            target['tests'] = []
                        target["tests"].append(function)
                        target[function] = {}
                        target[function]["pass"] = False
                        target[function]["tested"] = False
    return tests_found


if __name__ == '__main__':
    args = parse_args()
    sys.exit(test_targets(args))