1 /*
2  * Copyright (c) 2023 Intel Corporation.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include "zephyr/irq.h"
8 #include <zephyr/ztest.h>
9 #include <zephyr/kernel.h>
10 #include <zephyr/sys/util_loops.h>
11 #include <zephyr/timing/timing.h>
12 #include <zephyr/sys/spsc_lockfree.h>
13 #include <zephyr/sys/mpsc_lockfree.h>
14 
/* Queue and node storage for test_push_pop; only push_pop_nodes[0] is pushed there. */
static struct mpsc push_pop_q;
static struct mpsc_node push_pop_nodes[2];
17 
18 /*
19  * @brief Push and pop one element
20  *
21  * @see mpsc_push(), mpsc_pop()
22  *
23  * @ingroup tests
24  */
ZTEST(mpsc,test_push_pop)25 ZTEST(mpsc, test_push_pop)
26 {
27 
28 	mpsc_ptr_t node, head;
29 	struct mpsc_node *stub, *next, *tail;
30 
31 	mpsc_init(&push_pop_q);
32 
33 	head = mpsc_ptr_get(push_pop_q.head);
34 	tail = push_pop_q.tail;
35 	stub = &push_pop_q.stub;
36 	next = stub->next;
37 
38 	zassert_equal(head, stub, "Head should point at stub");
39 	zassert_equal(tail, stub, "Tail should point at stub");
40 	zassert_is_null(next, "Next should be null");
41 
42 	node = mpsc_pop(&push_pop_q);
43 	zassert_is_null(node, "Pop on empty queue should return null");
44 
45 	mpsc_push(&push_pop_q, &push_pop_nodes[0]);
46 
47 	head = mpsc_ptr_get(push_pop_q.head);
48 
49 	zassert_equal(head, &push_pop_nodes[0], "Queue head should point at push_pop_node");
50 	next = mpsc_ptr_get(push_pop_nodes[0].next);
51 	zassert_is_null(next, NULL, "push_pop_node next should point at null");
52 	next = mpsc_ptr_get(push_pop_q.stub.next);
53 	zassert_equal(next, &push_pop_nodes[0], "Queue stub should point at push_pop_node");
54 	tail = push_pop_q.tail;
55 	stub = &push_pop_q.stub;
56 	zassert_equal(tail, stub, "Tail should point at stub");
57 
58 	node = mpsc_pop(&push_pop_q);
59 	stub = &push_pop_q.stub;
60 
61 	zassert_not_equal(node, stub, "Pop should not return stub");
62 	zassert_not_null(node, "Pop should not return null");
63 	zassert_equal(node, &push_pop_nodes[0],
64 		      "Pop should return push_pop_node %p, instead was %p",
65 		      &push_pop_nodes[0], node);
66 
67 	node = mpsc_pop(&push_pop_q);
68 	zassert_is_null(node, "Pop on empty queue should return null");
69 }
70 
/* Number of nodes placed in each thread's SPSC free queue */
#define MPSC_FREEQ_SZ 8
/* Nodes pushed by each producer thread in test_mpsc_threaded */
#define MPSC_ITERATIONS 100000
#define MPSC_STACK_SIZE (512 + CONFIG_TEST_EXTRA_STACK_SIZE)
/* Thread count: one consumer (index 0) plus MPSC_THREADS_NUM - 1 producers */
#define MPSC_THREADS_NUM 4

/*
 * Per-thread bookkeeping. test_mpsc_threaded only fills in tid; executed,
 * priority and cpu_id are not referenced elsewhere in this file.
 */
struct thread_info {
	k_tid_t tid;
	int executed;
	int priority;
	int cpu_id;
};

/* Thread objects and stacks, indexed like the thread ids above */
static struct thread_info mpsc_tinfo[MPSC_THREADS_NUM];
static struct k_thread mpsc_thread[MPSC_THREADS_NUM];
static K_THREAD_STACK_ARRAY_DEFINE(mpsc_stack, MPSC_THREADS_NUM, MPSC_STACK_SIZE);
86 
/*
 * Node passed through mpsc_q: id is the producer that pushed it (and whose
 * free queue it returns to); n is the embedded link recovered via CONTAINER_OF.
 */
struct test_mpsc_node {
	uint32_t id;
	struct mpsc_node n;
};
91 
92 
/*
 * Common view of an SPSC queue of struct test_mpsc_node, so the per-thread
 * queues defined below can be stored in one array (node_q).
 * NOTE(review): the cast in SPSC_NAME() assumes this matches the layout that
 * SPSC_DEFINE() generates - confirm against spsc_lockfree.h.
 */
struct spsc_node_sq {
	struct spsc _spsc;
	struct test_mpsc_node *const buffer;
};
97 
98 #define TEST_SPSC_DEFINE(n, sz) SPSC_DEFINE(_spsc_##n, struct test_mpsc_node, sz)
99 #define SPSC_NAME(n, _) (struct spsc_node_sq *)&_spsc_##n
100 
101 LISTIFY(MPSC_THREADS_NUM, TEST_SPSC_DEFINE, (;), MPSC_FREEQ_SZ)
102 
103 struct spsc_node_sq *node_q[MPSC_THREADS_NUM] = {
104 	LISTIFY(MPSC_THREADS_NUM, SPSC_NAME, (,))
105 };
106 
107 static struct mpsc mpsc_q;
108 
mpsc_consumer(void * p1,void * p2,void * p3)109 static void mpsc_consumer(void *p1, void *p2, void *p3)
110 {
111 	ARG_UNUSED(p1);
112 	ARG_UNUSED(p2);
113 	ARG_UNUSED(p3);
114 
115 	struct mpsc_node *n;
116 	struct test_mpsc_node *nn;
117 
118 	for (int i = 0; i < (MPSC_ITERATIONS)*(MPSC_THREADS_NUM - 1); i++) {
119 		do {
120 			n = mpsc_pop(&mpsc_q);
121 			if (n == NULL) {
122 				k_yield();
123 			}
124 		} while (n == NULL);
125 
126 		zassert_not_equal(n, &mpsc_q.stub, "mpsc should not produce stub");
127 
128 		nn = CONTAINER_OF(n, struct test_mpsc_node, n);
129 
130 		spsc_acquire(node_q[nn->id]);
131 		spsc_produce(node_q[nn->id]);
132 	}
133 }
134 
mpsc_producer(void * p1,void * p2,void * p3)135 static void mpsc_producer(void *p1, void *p2, void *p3)
136 {
137 	ARG_UNUSED(p1);
138 	ARG_UNUSED(p2);
139 	ARG_UNUSED(p3);
140 
141 	struct test_mpsc_node *n;
142 	uint32_t id = (uint32_t)(uintptr_t)p1;
143 
144 	for (int i = 0; i < MPSC_ITERATIONS; i++) {
145 		do {
146 			n = spsc_consume(node_q[id]);
147 			if (n == NULL) {
148 				k_yield();
149 			}
150 		} while (n == NULL);
151 
152 		spsc_release(node_q[id]);
153 		n->id = id;
154 		mpsc_push(&mpsc_q, &n->n);
155 	}
156 }
157 
158 /**
159  * @brief Test that the producer and consumer are indeed thread safe
160  *
161  * This can and should be validated on SMP machines where incoherent
162  * memory could cause issues.
163  */
ZTEST(mpsc,test_mpsc_threaded)164 ZTEST(mpsc, test_mpsc_threaded)
165 {
166 	mpsc_init(&mpsc_q);
167 
168 	TC_PRINT("setting up mpsc producer free queues\n");
169 	/* Setup node free queues */
170 	for (int i = 0; i < MPSC_THREADS_NUM; i++) {
171 		for (int j = 0; j < MPSC_FREEQ_SZ; j++) {
172 			spsc_acquire(node_q[i]);
173 		}
174 		spsc_produce_all(node_q[i]);
175 	}
176 
177 	TC_PRINT("starting consumer\n");
178 	mpsc_tinfo[0].tid =
179 		k_thread_create(&mpsc_thread[0], mpsc_stack[0], MPSC_STACK_SIZE,
180 				mpsc_consumer,
181 				NULL, NULL, NULL,
182 				K_PRIO_PREEMPT(5),
183 				K_INHERIT_PERMS, K_NO_WAIT);
184 
185 	for (int i = 1; i < MPSC_THREADS_NUM; i++) {
186 		TC_PRINT("starting producer %i\n", i);
187 		mpsc_tinfo[i].tid =
188 			k_thread_create(&mpsc_thread[i], mpsc_stack[i], MPSC_STACK_SIZE,
189 					mpsc_producer,
190 					(void *)(uintptr_t)i, NULL, NULL,
191 					K_PRIO_PREEMPT(5),
192 					K_INHERIT_PERMS, K_NO_WAIT);
193 	}
194 
195 	for (int i = 0; i < MPSC_THREADS_NUM; i++) {
196 		TC_PRINT("joining mpsc thread %d\n", i);
197 		k_thread_join(mpsc_tinfo[i].tid, K_FOREVER);
198 	}
199 }
200 
201 #define THROUGHPUT_ITERS 100000
202 
ZTEST(mpsc,test_mpsc_throughput)203 ZTEST(mpsc, test_mpsc_throughput)
204 {
205 	struct mpsc_node node;
206 	timing_t start_time, end_time;
207 
208 	mpsc_init(&mpsc_q);
209 	timing_init();
210 	timing_start();
211 
212 	start_time = timing_counter_get();
213 
214 	int key = irq_lock();
215 
216 	for (int i = 0; i < THROUGHPUT_ITERS; i++) {
217 		mpsc_push(&mpsc_q, &node);
218 
219 		mpsc_pop(&mpsc_q);
220 	}
221 
222 	irq_unlock(key);
223 
224 	end_time = timing_counter_get();
225 
226 	uint64_t cycles = timing_cycles_get(&start_time, &end_time);
227 	uint64_t ns = timing_cycles_to_ns(cycles);
228 
229 	TC_PRINT("%llu ns for %d iterations, %llu ns per op\n", ns,
230 		 THROUGHPUT_ITERS, ns/THROUGHPUT_ITERS);
231 }
232 
/* Register the mpsc test suite; every hook argument is NULL (none used). */
ZTEST_SUITE(mpsc, NULL, NULL, NULL, NULL, NULL);
234