1 /*
2  * Copyright (c) 2023, Meta
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include "thrd.h"
8 
9 #include <stdint.h>
10 #include <threads.h>
11 
12 #include <zephyr/ztest.h>
13 
14 static tss_t key;
15 static int32_t destroyed_values[2];
16 static const int32_t forty_two = FORTY_TWO;
17 static const int32_t seventy_three = SEVENTY_THREE;
18 
19 static thrd_t thread1;
20 static thrd_t thread2;
21 
destroy_fn(void * arg)22 static void destroy_fn(void *arg)
23 {
24 	int32_t val = *(int32_t *)arg;
25 
26 	switch (val) {
27 	case FORTY_TWO:
28 		destroyed_values[0] = FORTY_TWO;
29 		break;
30 	case SEVENTY_THREE:
31 		destroyed_values[1] = SEVENTY_THREE;
32 		break;
33 	default:
34 		zassert_true(val == FORTY_TWO || val == SEVENTY_THREE, "unexpected val %d", val);
35 	}
36 }
37 
ZTEST(libc_tss,test_tss_create_delete)38 ZTEST(libc_tss, test_tss_create_delete)
39 {
40 	/* degenerate test cases */
41 	if (false) {
42 		/* pthread_key_create() has not been hardened against this */
43 		zassert_equal(thrd_error, tss_create(NULL, NULL));
44 		zassert_equal(thrd_error, tss_create(NULL, destroy_fn));
45 	}
46 	tss_delete(BIOS_FOOD);
47 
48 	/* happy path tested in before() / after() */
49 }
50 
thread_fn(void * arg)51 static int thread_fn(void *arg)
52 {
53 	int32_t val = *(int32_t *)arg;
54 
55 	zassert_equal(tss_get(key), NULL);
56 	tss_set(key, arg);
57 	zassert_equal(tss_get(key), arg);
58 
59 	return val;
60 }
61 
62 /* test out separate threads doing tss_get() / tss_set() */
ZTEST(libc_tss,test_tss_get_set)63 ZTEST(libc_tss, test_tss_get_set)
64 {
65 	int res1 = BIOS_FOOD;
66 	int res2 = BIOS_FOOD;
67 
68 	/* degenerate test cases */
69 	zassert_is_null(tss_get(BIOS_FOOD));
70 	zassert_not_equal(thrd_success, tss_set(FORTY_TWO, (void *)BIOS_FOOD));
71 	zassert_is_null(tss_get(FORTY_TWO));
72 
73 	zassert_equal(thrd_success, thrd_create(&thread1, thread_fn, (void *)&forty_two));
74 	zassert_equal(thrd_success, thrd_create(&thread2, thread_fn, (void *)&seventy_three));
75 
76 	zassert_equal(thrd_success, thrd_join(thread1, &res1));
77 	zassert_equal(thrd_success, thrd_join(thread2, &res2));
78 	zassert_equal(FORTY_TWO, res1);
79 	zassert_equal(SEVENTY_THREE, res2);
80 
81 	zassert_equal(destroyed_values[0], FORTY_TWO);
82 	zassert_equal(destroyed_values[1], SEVENTY_THREE);
83 }
84 
before(void * arg)85 static void before(void *arg)
86 {
87 	destroyed_values[0] = 0;
88 	destroyed_values[1] = 0;
89 
90 	zassert_equal(thrd_success, tss_create(&key, destroy_fn));
91 }
92 
after(void * arg)93 static void after(void *arg)
94 {
95 	tss_delete(key);
96 }
97 
/*
 * Register the libc_tss suite. Arguments are (name, predicate, setup, before,
 * after, teardown); only the per-test before() / after() hooks are supplied.
 */
ZTEST_SUITE(libc_tss, NULL, NULL, before, after, NULL);
99