1 /*
2 * Copyright (c) 2023 Intel Corporation.
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include "zephyr/irq.h"
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <zephyr/sys/util_loops.h>
11 #include <zephyr/timing/timing.h>
12 #include <zephyr/sys/spsc_lockfree.h>
13 #include <zephyr/sys/mpsc_lockfree.h>
14
15 static struct mpsc push_pop_q;
16 static struct mpsc_node push_pop_nodes[2];
17
18 /*
19 * @brief Push and pop one element
20 *
21 * @see mpsc_push(), mpsc_pop()
22 *
23 * @ingroup tests
24 */
ZTEST(mpsc,test_push_pop)25 ZTEST(mpsc, test_push_pop)
26 {
27
28 mpsc_ptr_t node, head;
29 struct mpsc_node *stub, *next, *tail;
30
31 mpsc_init(&push_pop_q);
32
33 head = mpsc_ptr_get(push_pop_q.head);
34 tail = push_pop_q.tail;
35 stub = &push_pop_q.stub;
36 next = stub->next;
37
38 zassert_equal(head, stub, "Head should point at stub");
39 zassert_equal(tail, stub, "Tail should point at stub");
40 zassert_is_null(next, "Next should be null");
41
42 node = mpsc_pop(&push_pop_q);
43 zassert_is_null(node, "Pop on empty queue should return null");
44
45 mpsc_push(&push_pop_q, &push_pop_nodes[0]);
46
47 head = mpsc_ptr_get(push_pop_q.head);
48
49 zassert_equal(head, &push_pop_nodes[0], "Queue head should point at push_pop_node");
50 next = mpsc_ptr_get(push_pop_nodes[0].next);
51 zassert_is_null(next, "push_pop_node next should point at null");
52 next = mpsc_ptr_get(push_pop_q.stub.next);
53 zassert_equal(next, &push_pop_nodes[0], "Queue stub should point at push_pop_node");
54 tail = push_pop_q.tail;
55 stub = &push_pop_q.stub;
56 zassert_equal(tail, stub, "Tail should point at stub");
57
58 node = mpsc_pop(&push_pop_q);
59 stub = &push_pop_q.stub;
60
61 zassert_not_equal(node, stub, "Pop should not return stub");
62 zassert_not_null(node, "Pop should not return null");
63 zassert_equal(node, &push_pop_nodes[0],
64 "Pop should return push_pop_node %p, instead was %p",
65 &push_pop_nodes[0], node);
66
67 node = mpsc_pop(&push_pop_q);
68 zassert_is_null(node, "Pop on empty queue should return null");
69 }
70
71 #define MPSC_FREEQ_SZ 8
72 #define MPSC_ITERATIONS 100000
73 #define MPSC_STACK_SIZE (512 + CONFIG_TEST_EXTRA_STACK_SIZE)
74 #define MPSC_THREADS_NUM 4
75
76 struct thread_info {
77 k_tid_t tid;
78 int executed;
79 int priority;
80 int cpu_id;
81 };
82
83 static struct thread_info mpsc_tinfo[MPSC_THREADS_NUM];
84 static struct k_thread mpsc_thread[MPSC_THREADS_NUM];
85 static K_THREAD_STACK_ARRAY_DEFINE(mpsc_stack, MPSC_THREADS_NUM, MPSC_STACK_SIZE);
86
87 struct test_mpsc_node {
88 uint32_t id;
89 struct mpsc_node n;
90 };
91
92
93 struct spsc_node_sq {
94 struct spsc _spsc;
95 struct test_mpsc_node *const buffer;
96 };
97
98 #define TEST_SPSC_DEFINE(n, sz) SPSC_DEFINE(_spsc_##n, struct test_mpsc_node, sz)
99 #define SPSC_NAME(n, _) (struct spsc_node_sq *)&_spsc_##n
100
101 LISTIFY(MPSC_THREADS_NUM, TEST_SPSC_DEFINE, (;), MPSC_FREEQ_SZ)
102
103 struct spsc_node_sq *node_q[MPSC_THREADS_NUM] = {
104 LISTIFY(MPSC_THREADS_NUM, SPSC_NAME, (,))
105 };
106
107 static struct mpsc mpsc_q;
108
mpsc_consumer(void * p1,void * p2,void * p3)109 static void mpsc_consumer(void *p1, void *p2, void *p3)
110 {
111 ARG_UNUSED(p1);
112 ARG_UNUSED(p2);
113 ARG_UNUSED(p3);
114
115 struct mpsc_node *n;
116 struct test_mpsc_node *nn;
117
118 for (int i = 0; i < (MPSC_ITERATIONS)*(MPSC_THREADS_NUM - 1); i++) {
119 do {
120 n = mpsc_pop(&mpsc_q);
121 if (n == NULL) {
122 k_yield();
123 }
124 } while (n == NULL);
125
126 zassert_not_equal(n, &mpsc_q.stub, "mpsc should not produce stub");
127
128 nn = CONTAINER_OF(n, struct test_mpsc_node, n);
129
130 /* Return node to producer's free queue - must retry if queue is full */
131 while (spsc_acquire(node_q[nn->id]) == NULL) {
132 k_yield();
133 }
134 spsc_produce(node_q[nn->id]);
135 }
136 }
137
mpsc_producer(void * p1,void * p2,void * p3)138 static void mpsc_producer(void *p1, void *p2, void *p3)
139 {
140 ARG_UNUSED(p1);
141 ARG_UNUSED(p2);
142 ARG_UNUSED(p3);
143
144 struct test_mpsc_node *n;
145 uint32_t id = (uint32_t)(uintptr_t)p1;
146
147 for (int i = 0; i < MPSC_ITERATIONS; i++) {
148 do {
149 n = spsc_consume(node_q[id]);
150 if (n == NULL) {
151 k_yield();
152 }
153 } while (n == NULL);
154
155 spsc_release(node_q[id]);
156 n->id = id;
157 mpsc_push(&mpsc_q, &n->n);
158 }
159 }
160
161 /**
162 * @brief Test that the producer and consumer are indeed thread safe
163 *
164 * This can and should be validated on SMP machines where incoherent
165 * memory could cause issues.
166 */
ZTEST(mpsc,test_mpsc_threaded)167 ZTEST(mpsc, test_mpsc_threaded)
168 {
169 mpsc_init(&mpsc_q);
170
171 TC_PRINT("setting up mpsc producer free queues\n");
172 /* Setup node free queues */
173 for (int i = 0; i < MPSC_THREADS_NUM; i++) {
174 for (int j = 0; j < MPSC_FREEQ_SZ; j++) {
175 spsc_acquire(node_q[i]);
176 }
177 spsc_produce_all(node_q[i]);
178 }
179
180 TC_PRINT("starting consumer\n");
181 mpsc_tinfo[0].tid =
182 k_thread_create(&mpsc_thread[0], mpsc_stack[0], MPSC_STACK_SIZE,
183 mpsc_consumer,
184 NULL, NULL, NULL,
185 K_PRIO_PREEMPT(5),
186 K_INHERIT_PERMS, K_NO_WAIT);
187
188 for (int i = 1; i < MPSC_THREADS_NUM; i++) {
189 TC_PRINT("starting producer %i\n", i);
190 mpsc_tinfo[i].tid =
191 k_thread_create(&mpsc_thread[i], mpsc_stack[i], MPSC_STACK_SIZE,
192 mpsc_producer,
193 (void *)(uintptr_t)i, NULL, NULL,
194 K_PRIO_PREEMPT(5),
195 K_INHERIT_PERMS, K_NO_WAIT);
196 }
197
198 for (int i = 0; i < MPSC_THREADS_NUM; i++) {
199 TC_PRINT("joining mpsc thread %d\n", i);
200 k_thread_join(mpsc_tinfo[i].tid, K_FOREVER);
201 }
202 }
203
204 #define THROUGHPUT_ITERS 100000
205
ZTEST(mpsc,test_mpsc_throughput)206 ZTEST(mpsc, test_mpsc_throughput)
207 {
208 struct mpsc_node node;
209 timing_t start_time, end_time;
210
211 mpsc_init(&mpsc_q);
212 timing_init();
213 timing_start();
214
215 start_time = timing_counter_get();
216
217 int key = irq_lock();
218
219 for (int i = 0; i < THROUGHPUT_ITERS; i++) {
220 mpsc_push(&mpsc_q, &node);
221
222 mpsc_pop(&mpsc_q);
223 }
224
225 irq_unlock(key);
226
227 end_time = timing_counter_get();
228
229 uint64_t cycles = timing_cycles_get(&start_time, &end_time);
230 uint64_t ns = timing_cycles_to_ns(cycles);
231
232 TC_PRINT("%llu ns for %d iterations, %llu ns per op\n", ns,
233 THROUGHPUT_ITERS, ns/THROUGHPUT_ITERS);
234 }
235
236 ZTEST_SUITE(mpsc, NULL, NULL, NULL, NULL, NULL);
237