1 /*
2  * Copyright (c) 2023 Intel Corporation.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include "zephyr/irq.h"
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <zephyr/sys/util_loops.h>
11 #include <zephyr/timing/timing.h>
12 #include <zephyr/sys/spsc_lockfree.h>
13 #include <zephyr/sys/mpsc_lockfree.h>
14 
15 static struct mpsc push_pop_q;
16 static struct mpsc_node push_pop_nodes[2];
17 
18 /*
19  * @brief Push and pop one element
20  *
21  * @see mpsc_push(), mpsc_pop()
22  *
23  * @ingroup tests
24  */
ZTEST(mpsc,test_push_pop)25 ZTEST(mpsc, test_push_pop)
26 {
27 
28 	mpsc_ptr_t node, head;
29 	struct mpsc_node *stub, *next, *tail;
30 
31 	mpsc_init(&push_pop_q);
32 
33 	head = mpsc_ptr_get(push_pop_q.head);
34 	tail = push_pop_q.tail;
35 	stub = &push_pop_q.stub;
36 	next = stub->next;
37 
38 	zassert_equal(head, stub, "Head should point at stub");
39 	zassert_equal(tail, stub, "Tail should point at stub");
40 	zassert_is_null(next, "Next should be null");
41 
42 	node = mpsc_pop(&push_pop_q);
43 	zassert_is_null(node, "Pop on empty queue should return null");
44 
45 	mpsc_push(&push_pop_q, &push_pop_nodes[0]);
46 
47 	head = mpsc_ptr_get(push_pop_q.head);
48 
49 	zassert_equal(head, &push_pop_nodes[0], "Queue head should point at push_pop_node");
50 	next = mpsc_ptr_get(push_pop_nodes[0].next);
51 	zassert_is_null(next, "push_pop_node next should point at null");
52 	next = mpsc_ptr_get(push_pop_q.stub.next);
53 	zassert_equal(next, &push_pop_nodes[0], "Queue stub should point at push_pop_node");
54 	tail = push_pop_q.tail;
55 	stub = &push_pop_q.stub;
56 	zassert_equal(tail, stub, "Tail should point at stub");
57 
58 	node = mpsc_pop(&push_pop_q);
59 	stub = &push_pop_q.stub;
60 
61 	zassert_not_equal(node, stub, "Pop should not return stub");
62 	zassert_not_null(node, "Pop should not return null");
63 	zassert_equal(node, &push_pop_nodes[0],
64 		      "Pop should return push_pop_node %p, instead was %p",
65 		      &push_pop_nodes[0], node);
66 
67 	node = mpsc_pop(&push_pop_q);
68 	zassert_is_null(node, "Pop on empty queue should return null");
69 }
70 
/*
 * Number of nodes pre-allocated in each per-thread free queue; used as the
 * size argument of SPSC_DEFINE() below. NOTE(review): SPSC queues presumably
 * require a power-of-two size - confirm in spsc_lockfree.h before changing.
 */
#define MPSC_FREEQ_SZ 8
/* Number of nodes each producer thread pushes in test_mpsc_threaded. */
#define MPSC_ITERATIONS 100000
/* Stack size of every consumer/producer thread created by test_mpsc_threaded. */
#define MPSC_STACK_SIZE (512 + CONFIG_TEST_EXTRA_STACK_SIZE)
/*
 * Total thread count for test_mpsc_threaded: index 0 is the consumer,
 * indices 1..MPSC_THREADS_NUM-1 are producers. This value is passed to
 * LISTIFY(), which expects an integer literal, so keep it a plain number.
 */
#define MPSC_THREADS_NUM 4

/*
 * Per-thread bookkeeping for test_mpsc_threaded. Only @c tid is used in this
 * file (stored at creation, passed to k_thread_join()); @c executed,
 * @c priority and @c cpu_id are not read or written here.
 */
struct thread_info {
	k_tid_t tid;
	int executed;
	int priority;
	int cpu_id;
};

/* Thread handles, control blocks and stacks, indexed like the thread ids above. */
static struct thread_info mpsc_tinfo[MPSC_THREADS_NUM];
static struct k_thread mpsc_thread[MPSC_THREADS_NUM];
static K_THREAD_STACK_ARRAY_DEFINE(mpsc_stack, MPSC_THREADS_NUM, MPSC_STACK_SIZE);
86 
/*
 * Node circulated between producers and the consumer in test_mpsc_threaded.
 * @c id is set by the producer to its own index, and the consumer uses it to
 * choose which free queue to return the node to. @c n is the intrusive link
 * passed to mpsc_push(); the consumer recovers the enclosing node from it with
 * CONTAINER_OF().
 */
struct test_mpsc_node {
	uint32_t id;
	struct mpsc_node n;
};
91 
92 
93 struct spsc_node_sq {
94 	struct spsc _spsc;
95 	struct test_mpsc_node *const buffer;
96 };
97 
98 #define TEST_SPSC_DEFINE(n, sz) SPSC_DEFINE(_spsc_##n, struct test_mpsc_node, sz)
99 #define SPSC_NAME(n, _) (struct spsc_node_sq *)&_spsc_##n
100 
101 LISTIFY(MPSC_THREADS_NUM, TEST_SPSC_DEFINE, (;), MPSC_FREEQ_SZ)
102 
103 struct spsc_node_sq *node_q[MPSC_THREADS_NUM] = {
104 	LISTIFY(MPSC_THREADS_NUM, SPSC_NAME, (,))
105 };
106 
107 static struct mpsc mpsc_q;
108 
mpsc_consumer(void * p1,void * p2,void * p3)109 static void mpsc_consumer(void *p1, void *p2, void *p3)
110 {
111 	ARG_UNUSED(p1);
112 	ARG_UNUSED(p2);
113 	ARG_UNUSED(p3);
114 
115 	struct mpsc_node *n;
116 	struct test_mpsc_node *nn;
117 
118 	for (int i = 0; i < (MPSC_ITERATIONS)*(MPSC_THREADS_NUM - 1); i++) {
119 		do {
120 			n = mpsc_pop(&mpsc_q);
121 			if (n == NULL) {
122 				k_yield();
123 			}
124 		} while (n == NULL);
125 
126 		zassert_not_equal(n, &mpsc_q.stub, "mpsc should not produce stub");
127 
128 		nn = CONTAINER_OF(n, struct test_mpsc_node, n);
129 
130 		/* Return node to producer's free queue - must retry if queue is full */
131 		while (spsc_acquire(node_q[nn->id]) == NULL) {
132 			k_yield();
133 		}
134 		spsc_produce(node_q[nn->id]);
135 	}
136 }
137 
mpsc_producer(void * p1,void * p2,void * p3)138 static void mpsc_producer(void *p1, void *p2, void *p3)
139 {
140 	ARG_UNUSED(p1);
141 	ARG_UNUSED(p2);
142 	ARG_UNUSED(p3);
143 
144 	struct test_mpsc_node *n;
145 	uint32_t id = (uint32_t)(uintptr_t)p1;
146 
147 	for (int i = 0; i < MPSC_ITERATIONS; i++) {
148 		do {
149 			n = spsc_consume(node_q[id]);
150 			if (n == NULL) {
151 				k_yield();
152 			}
153 		} while (n == NULL);
154 
155 		spsc_release(node_q[id]);
156 		n->id = id;
157 		mpsc_push(&mpsc_q, &n->n);
158 	}
159 }
160 
161 /**
162  * @brief Test that the producer and consumer are indeed thread safe
163  *
164  * This can and should be validated on SMP machines where incoherent
165  * memory could cause issues.
166  */
ZTEST(mpsc,test_mpsc_threaded)167 ZTEST(mpsc, test_mpsc_threaded)
168 {
169 	mpsc_init(&mpsc_q);
170 
171 	TC_PRINT("setting up mpsc producer free queues\n");
172 	/* Setup node free queues */
173 	for (int i = 0; i < MPSC_THREADS_NUM; i++) {
174 		for (int j = 0; j < MPSC_FREEQ_SZ; j++) {
175 			spsc_acquire(node_q[i]);
176 		}
177 		spsc_produce_all(node_q[i]);
178 	}
179 
180 	TC_PRINT("starting consumer\n");
181 	mpsc_tinfo[0].tid =
182 		k_thread_create(&mpsc_thread[0], mpsc_stack[0], MPSC_STACK_SIZE,
183 				mpsc_consumer,
184 				NULL, NULL, NULL,
185 				K_PRIO_PREEMPT(5),
186 				K_INHERIT_PERMS, K_NO_WAIT);
187 
188 	for (int i = 1; i < MPSC_THREADS_NUM; i++) {
189 		TC_PRINT("starting producer %i\n", i);
190 		mpsc_tinfo[i].tid =
191 			k_thread_create(&mpsc_thread[i], mpsc_stack[i], MPSC_STACK_SIZE,
192 					mpsc_producer,
193 					(void *)(uintptr_t)i, NULL, NULL,
194 					K_PRIO_PREEMPT(5),
195 					K_INHERIT_PERMS, K_NO_WAIT);
196 	}
197 
198 	for (int i = 0; i < MPSC_THREADS_NUM; i++) {
199 		TC_PRINT("joining mpsc thread %d\n", i);
200 		k_thread_join(mpsc_tinfo[i].tid, K_FOREVER);
201 	}
202 }
203 
204 #define THROUGHPUT_ITERS 100000
205 
ZTEST(mpsc,test_mpsc_throughput)206 ZTEST(mpsc, test_mpsc_throughput)
207 {
208 	struct mpsc_node node;
209 	timing_t start_time, end_time;
210 
211 	mpsc_init(&mpsc_q);
212 	timing_init();
213 	timing_start();
214 
215 	start_time = timing_counter_get();
216 
217 	int key = irq_lock();
218 
219 	for (int i = 0; i < THROUGHPUT_ITERS; i++) {
220 		mpsc_push(&mpsc_q, &node);
221 
222 		mpsc_pop(&mpsc_q);
223 	}
224 
225 	irq_unlock(key);
226 
227 	end_time = timing_counter_get();
228 
229 	uint64_t cycles = timing_cycles_get(&start_time, &end_time);
230 	uint64_t ns = timing_cycles_to_ns(cycles);
231 
232 	TC_PRINT("%llu ns for %d iterations, %llu ns per op\n", ns,
233 		 THROUGHPUT_ITERS, ns/THROUGHPUT_ITERS);
234 }
235 
/* Register the mpsc test suite; every suite callback is NULL (no predicate or fixture hooks). */
ZTEST_SUITE(mpsc, NULL, NULL, NULL, NULL, NULL);
237