1 /*
2  * Copyright (c) 2022 Rodrigo Peixoto <rodrigopex@gmail.com>
3  * SPDX-License-Identifier: Apache-2.0
4  */
5 
6 #include <zephyr/kernel.h>
7 #include <zephyr/init.h>
8 #include <zephyr/sys/iterable_sections.h>
9 #include <zephyr/logging/log.h>
10 #include <zephyr/sys/printk.h>
11 #include <zephyr/net/buf.h>
12 #include <zephyr/zbus/zbus.h>
13 LOG_MODULE_REGISTER(zbus, CONFIG_ZBUS_LOG_LEVEL);
14 
15 #if defined(CONFIG_ZBUS_PRIORITY_BOOST)
16 /* Available only when the priority boost is enabled */
17 static struct k_spinlock _zbus_chan_slock;
18 #endif /* CONFIG_ZBUS_PRIORITY_BOOST */
19 
20 static struct k_spinlock obs_slock;
21 
#if defined(CONFIG_ZBUS_MSG_SUBSCRIBER)

#if defined(CONFIG_ZBUS_MSG_SUBSCRIBER_BUF_ALLOC_DYNAMIC)

/*
 * Heap-backed pool: buffers are allocated with exactly the requested data length. The
 * user data area of each buffer is sized to hold one channel pointer.
 */
NET_BUF_POOL_HEAP_DEFINE(_zbus_msg_subscribers_pool, CONFIG_ZBUS_MSG_SUBSCRIBER_NET_BUF_POOL_SIZE,
			 sizeof(struct zbus_channel *), NULL);
BUILD_ASSERT(K_HEAP_MEM_POOL_SIZE > 0, "MSG_SUBSCRIBER feature requires heap memory pool.");

/**
 * @brief Allocate a buffer able to hold @p size bytes of message data.
 *
 * @param pool    Pool to allocate from.
 * @param size    Number of data bytes the buffer must hold.
 * @param timeout Maximum time to wait for the allocation.
 *
 * @return The new buffer, or NULL if it could not be allocated.
 */
static inline struct net_buf *_zbus_create_net_buf(struct net_buf_pool *pool, size_t size,
						   k_timeout_t timeout)
{
	/* Honor the pool argument instead of silently using the global pool. */
	return net_buf_alloc_len(pool, size, timeout);
}

#else

/*
 * Fixed-size pool: every buffer carries CONFIG_ZBUS_MSG_SUBSCRIBER_NET_BUF_STATIC_DATA_SIZE
 * bytes of data and a user data area sized to hold one channel pointer.
 */
NET_BUF_POOL_FIXED_DEFINE(_zbus_msg_subscribers_pool,
			  (CONFIG_ZBUS_MSG_SUBSCRIBER_NET_BUF_POOL_SIZE),
			  (CONFIG_ZBUS_MSG_SUBSCRIBER_NET_BUF_STATIC_DATA_SIZE),
			  sizeof(struct zbus_channel *), NULL);

/**
 * @brief Allocate a fixed-size buffer able to hold @p size bytes of message data.
 *
 * @param pool    Pool to allocate from.
 * @param size    Number of data bytes the buffer must hold.
 * @param timeout Maximum time to wait for the allocation.
 *
 * @return The new buffer, or NULL if @p size exceeds the fixed data size or no buffer
 *         could be allocated.
 */
static inline struct net_buf *_zbus_create_net_buf(struct net_buf_pool *pool, size_t size,
						   k_timeout_t timeout)
{
	__ASSERT(size <= CONFIG_ZBUS_MSG_SUBSCRIBER_NET_BUF_STATIC_DATA_SIZE,
		 "CONFIG_ZBUS_MSG_SUBSCRIBER_NET_BUF_STATIC_DATA_SIZE must be greater or equal to "
		 "%d",
		 (int)size);

	/*
	 * __ASSERT() is compiled out when CONFIG_ASSERT is disabled. Refuse the allocation
	 * instead of handing out a buffer the caller would then overfill.
	 */
	if (size > CONFIG_ZBUS_MSG_SUBSCRIBER_NET_BUF_STATIC_DATA_SIZE) {
		return NULL;
	}

	return net_buf_alloc(pool, timeout);
}
#endif /* CONFIG_ZBUS_MSG_SUBSCRIBER_BUF_ALLOC_DYNAMIC */

#endif /* CONFIG_ZBUS_MSG_SUBSCRIBER */
55 
_zbus_init(void)56 int _zbus_init(void)
57 {
58 
59 	const struct zbus_channel *curr = NULL;
60 	const struct zbus_channel *prev = NULL;
61 
62 	STRUCT_SECTION_FOREACH(zbus_channel_observation, observation) {
63 		curr = observation->chan;
64 
65 		if (prev != curr) {
66 			if (prev == NULL) {
67 				curr->data->observers_start_idx = 0;
68 				curr->data->observers_end_idx = 0;
69 			} else {
70 				curr->data->observers_start_idx = prev->data->observers_end_idx;
71 				curr->data->observers_end_idx = prev->data->observers_end_idx;
72 			}
73 			prev = curr;
74 		}
75 
76 		++(curr->data->observers_end_idx);
77 	}
78 	STRUCT_SECTION_FOREACH(zbus_channel, chan) {
79 		k_sem_init(&chan->data->sem, 1, 1);
80 
81 #if defined(CONFIG_ZBUS_RUNTIME_OBSERVERS)
82 		sys_slist_init(&chan->data->observers);
83 #endif /* CONFIG_ZBUS_RUNTIME_OBSERVERS */
84 	}
85 	return 0;
86 }
87 SYS_INIT(_zbus_init, APPLICATION, CONFIG_ZBUS_CHANNELS_SYS_INIT_PRIORITY);
88 
/**
 * @brief Deliver one notification about @p chan to a single observer.
 *
 * - Listener: its callback is invoked directly, in the caller's context.
 * - Subscriber: the channel pointer is put into its message queue.
 * - Message subscriber: a clone of @p buf, tagged with the channel pointer in its user
 *   data, is put into its FIFO.
 *
 * @param chan     Channel being notified about.
 * @param obs      Observer to notify.
 * @param end_time Deadline for any blocking step (queue put, buffer clone).
 * @param buf      Buffer holding the message; only used for message subscribers.
 *
 * @return 0 on success, the k_msgq_put() error for subscribers, or -ENOMEM when the
 *         buffer clone could not be allocated.
 */
static inline int _zbus_notify_observer(const struct zbus_channel *chan,
					const struct zbus_observer *obs, k_timepoint_t end_time,
					struct net_buf *buf)
{
	switch (obs->type) {
	case ZBUS_OBSERVER_LISTENER_TYPE:
		obs->callback(chan);
		return 0;

	case ZBUS_OBSERVER_SUBSCRIBER_TYPE:
		return k_msgq_put(obs->queue, &chan, sys_timepoint_timeout(end_time));

#if defined(CONFIG_ZBUS_MSG_SUBSCRIBER)
	case ZBUS_OBSERVER_MSG_SUBSCRIBER_TYPE: {
		struct net_buf *copy = net_buf_clone(buf, sys_timepoint_timeout(end_time));

		if (copy == NULL) {
			return -ENOMEM;
		}

		/* The receiver reads the source channel back from the user data area. */
		memcpy(net_buf_user_data(copy), &chan, sizeof(struct zbus_channel *));
		net_buf_put(obs->message_fifo, copy);
		return 0;
	}
#endif /* CONFIG_ZBUS_MSG_SUBSCRIBER */

	default:
		_ZBUS_ASSERT(false, "Unreachable");
		return 0;
	}
}
121 
/**
 * @brief Notify every enabled observer of @p chan (the "VDED" pass).
 *
 * Static observers are notified first, in the channel's index range of the observation
 * section, skipping disabled observers and observers masked for this channel. Runtime
 * observers (CONFIG_ZBUS_RUNTIME_OBSERVERS) follow, skipping disabled ones. With
 * CONFIG_ZBUS_MSG_SUBSCRIBER the current message is first copied into one buffer, and each
 * message subscriber receives a clone of that buffer.
 *
 * The callers in this file invoke it with the channel locked via chan_lock().
 *
 * @param chan     Channel whose observers are notified.
 * @param end_time Deadline shared by every blocking step (allocation, queue puts).
 *
 * @retval 0       Every observer was notified.
 * @retval -ENOMEM The message buffer could not be allocated (no observer notified), or a
 *                 static message subscriber's clone failed (remaining observers skipped).
 * @retval other   Error of the last observer that could not be notified.
 */
static inline int _zbus_vded_exec(const struct zbus_channel *chan, k_timepoint_t end_time)
{
	int err = 0;
	int last_error = 0;
	struct net_buf *buf = NULL;

	/* Static observer event dispatcher logic */
	struct zbus_channel_observation *observation;
	struct zbus_channel_observation_mask *observation_mask;

#if defined(CONFIG_ZBUS_MSG_SUBSCRIBER)
	buf = _zbus_create_net_buf(&_zbus_msg_subscribers_pool, zbus_chan_msg_size(chan),
				   sys_timepoint_timeout(end_time));

	_ZBUS_ASSERT(buf != NULL, "net_buf zbus_msg_subscribers_pool is "
				  "unavailable or heap is full");

	/*
	 * The assertion is compiled out when assertions are disabled, and the allocation can
	 * fail at runtime (timeout, exhausted pool or heap). Never pass NULL to
	 * net_buf_add_mem().
	 */
	if (buf == NULL) {
		LOG_ERR("could not allocate a message buffer for channel %s",
			_ZBUS_CHAN_NAME(chan));
		return -ENOMEM;
	}

	net_buf_add_mem(buf, zbus_chan_msg(chan), zbus_chan_msg_size(chan));
#endif /* CONFIG_ZBUS_MSG_SUBSCRIBER */

	LOG_DBG("Notifing %s's observers. Starting VDED:", _ZBUS_CHAN_NAME(chan));

	int __maybe_unused index = 0;

	for (int16_t i = chan->data->observers_start_idx, limit = chan->data->observers_end_idx;
	     i < limit; ++i) {
		STRUCT_SECTION_GET(zbus_channel_observation, i, &observation);
		STRUCT_SECTION_GET(zbus_channel_observation_mask, i, &observation_mask);

		_ZBUS_ASSERT(observation != NULL, "observation must be not NULL");

		const struct zbus_observer *obs = observation->obs;

		/* Skip observers that are disabled or masked for this channel. */
		if (!obs->data->enabled || observation_mask->enabled) {
			continue;
		}

		err = _zbus_notify_observer(chan, obs, end_time, buf);

		if (err) {
			last_error = err;
			LOG_ERR("could not deliver notification to observer %s. Error code %d",
				_ZBUS_OBS_NAME(obs), err);
			if (err == -ENOMEM) {
				/* A buffer clone failed: stop delivering to the rest. */
				if (IS_ENABLED(CONFIG_ZBUS_MSG_SUBSCRIBER)) {
					net_buf_unref(buf);
				}
				return err;
			}
		}

		LOG_DBG(" %d -> %s", index++, _ZBUS_OBS_NAME(obs));
	}

#if defined(CONFIG_ZBUS_RUNTIME_OBSERVERS)
	/* Dynamic observer event dispatcher logic */
	struct zbus_observer_node *obs_nd, *tmp;

	SYS_SLIST_FOR_EACH_CONTAINER_SAFE(&chan->data->observers, obs_nd, tmp, node) {

		const struct zbus_observer *obs = obs_nd->obs;

		if (!obs->data->enabled) {
			continue;
		}

		err = _zbus_notify_observer(chan, obs, end_time, buf);

		if (err) {
			last_error = err;
			/* Report runtime-observer failures the same way as static ones. */
			LOG_ERR("could not deliver notification to observer %s. Error code %d",
				_ZBUS_OBS_NAME(obs), err);
		}
	}
#endif /* CONFIG_ZBUS_RUNTIME_OBSERVERS */

	/* Drop this function's reference; subscribers hold their own clones. */
	IF_ENABLED(CONFIG_ZBUS_MSG_SUBSCRIBER, (net_buf_unref(buf);))

	return last_error;
}
200 
201 #if defined(CONFIG_ZBUS_PRIORITY_BOOST)
202 
chan_update_hop(const struct zbus_channel * chan)203 static inline void chan_update_hop(const struct zbus_channel *chan)
204 {
205 	struct zbus_channel_observation *observation;
206 	struct zbus_channel_observation_mask *observation_mask;
207 
208 	int chan_highest_observer_priority = ZBUS_MIN_THREAD_PRIORITY;
209 
210 	K_SPINLOCK(&_zbus_chan_slock) {
211 		const int limit = chan->data->observers_end_idx;
212 
213 		for (int16_t i = chan->data->observers_start_idx; i < limit; ++i) {
214 			STRUCT_SECTION_GET(zbus_channel_observation, i, &observation);
215 			STRUCT_SECTION_GET(zbus_channel_observation_mask, i, &observation_mask);
216 
217 			__ASSERT(observation != NULL, "observation must be not NULL");
218 
219 			const struct zbus_observer *obs = observation->obs;
220 
221 			if (!obs->data->enabled || observation_mask->enabled) {
222 				continue;
223 			}
224 
225 			if (chan_highest_observer_priority > obs->data->priority) {
226 				chan_highest_observer_priority = obs->data->priority;
227 			}
228 		}
229 		chan->data->highest_observer_priority = chan_highest_observer_priority;
230 	}
231 }
232 
update_all_channels_hop(const struct zbus_observer * obs)233 static inline void update_all_channels_hop(const struct zbus_observer *obs)
234 {
235 	struct zbus_channel_observation *observation;
236 
237 	int count;
238 
239 	STRUCT_SECTION_COUNT(zbus_channel_observation, &count);
240 
241 	for (int16_t i = 0; i < count; ++i) {
242 		STRUCT_SECTION_GET(zbus_channel_observation, i, &observation);
243 
244 		if (obs != observation->obs) {
245 			continue;
246 		}
247 
248 		chan_update_hop(observation->chan);
249 	}
250 }
251 
zbus_obs_attach_to_thread(const struct zbus_observer * obs)252 int zbus_obs_attach_to_thread(const struct zbus_observer *obs)
253 {
254 	_ZBUS_ASSERT(!k_is_in_isr(), "cannot attach to an ISR");
255 	_ZBUS_ASSERT(obs != NULL, "obs is required");
256 
257 	int current_thread_priority = k_thread_priority_get(k_current_get());
258 
259 	K_SPINLOCK(&obs_slock) {
260 		if (obs->data->priority != current_thread_priority) {
261 			obs->data->priority = current_thread_priority;
262 
263 			update_all_channels_hop(obs);
264 		}
265 	}
266 
267 	return 0;
268 }
269 
zbus_obs_detach_from_thread(const struct zbus_observer * obs)270 int zbus_obs_detach_from_thread(const struct zbus_observer *obs)
271 {
272 	_ZBUS_ASSERT(!k_is_in_isr(), "cannot detach from an ISR");
273 	_ZBUS_ASSERT(obs != NULL, "obs is required");
274 
275 	K_SPINLOCK(&obs_slock) {
276 		obs->data->priority = ZBUS_MIN_THREAD_PRIORITY;
277 
278 		update_all_channels_hop(obs);
279 	}
280 
281 	return 0;
282 }
283 
284 #else
285 
update_all_channels_hop(const struct zbus_observer * obs)286 static inline void update_all_channels_hop(const struct zbus_observer *obs)
287 {
288 }
289 
290 #endif /* CONFIG_ZBUS_PRIORITY_BOOST */
291 
/*
 * Take the channel semaphore, boosting the calling thread first when needed.
 *
 * With CONFIG_ZBUS_PRIORITY_BOOST and a thread caller, the caller's current priority is
 * stored in *prio. If that priority is numerically greater than the channel's HOP, the
 * thread is set to HOP - 1 (clamped at 0) before waiting. It is set back to *prio if the
 * semaphore cannot be taken.
 *
 * @param chan    Channel to lock.
 * @param timeout How long to wait for the channel semaphore.
 * @param prio    Out: the caller's original priority (thread context with boost enabled);
 *                left untouched otherwise.
 *
 * @return 0 on success, or the error returned by k_sem_take().
 */
static inline int chan_lock(const struct zbus_channel *chan, k_timeout_t timeout, int *prio)
{
	bool boosted = false;

#if defined(CONFIG_ZBUS_PRIORITY_BOOST)
	if (!k_is_in_isr()) {
		const k_tid_t self = k_current_get();

		*prio = k_thread_priority_get(self);

		K_SPINLOCK(&_zbus_chan_slock) {
			const int hop = chan->data->highest_observer_priority;

			if (*prio > hop) {
				/* Run one level above the most urgent observer, never below 0. */
				k_thread_priority_set(self, MAX(hop - 1, 0));
				boosted = true;
			}
		}
	}
#endif /* CONFIG_ZBUS_PRIORITY_BOOST */

	const int ret = k_sem_take(&chan->data->sem, timeout);

	/* Channel not taken: undo the boost. Always false when boosting is compiled out. */
	if ((ret != 0) && boosted) {
		k_thread_priority_set(k_current_get(), *prio);
	}

	return ret;
}
331 
chan_unlock(const struct zbus_channel * chan,int prio)332 static inline void chan_unlock(const struct zbus_channel *chan, int prio)
333 {
334 	k_sem_give(&chan->data->sem);
335 
336 #if defined(CONFIG_ZBUS_PRIORITY_BOOST)
337 	/* During the unlock phase, with the priority boost enabled, the priority must be
338 	 * restored to the original value in case it was elevated
339 	 */
340 	if (prio < ZBUS_MIN_THREAD_PRIORITY) {
341 		k_thread_priority_set(k_current_get(), prio);
342 	}
343 #endif /* CONFIG_ZBUS_PRIORITY_BOOST */
344 }
345 
/**
 * @brief Publish @p msg on @p chan and notify its observers.
 *
 * The message is checked by the channel validator (if any), copied into the channel
 * under the channel lock, and the observers are notified before the lock is released.
 * From an ISR the call never blocks.
 *
 * @param chan    Channel to publish on.
 * @param msg     Message to copy; must be chan->message_size bytes.
 * @param timeout Maximum time for locking and notification.
 *
 * @retval -ENOMSG The validator rejected the message.
 * @return Otherwise the chan_lock() error, or the result of notifying the observers.
 */
int zbus_chan_pub(const struct zbus_channel *chan, const void *msg, k_timeout_t timeout)
{
	_ZBUS_ASSERT(chan != NULL, "chan is required");
	_ZBUS_ASSERT(msg != NULL, "msg is required");

	if (k_is_in_isr()) {
		/* ISRs must not block. */
		timeout = K_NO_WAIT;
	}

	const k_timepoint_t deadline = sys_timepoint_calc(timeout);

	if ((chan->validator != NULL) && !chan->validator(msg, chan->message_size)) {
		return -ENOMSG;
	}

	int saved_prio = ZBUS_MIN_THREAD_PRIORITY;
	int ret = chan_lock(chan, timeout, &saved_prio);

	if (ret == 0) {
		memcpy(chan->message, msg, chan->message_size);
		ret = _zbus_vded_exec(chan, deadline);
		chan_unlock(chan, saved_prio);
	}

	return ret;
}
378 
/**
 * @brief Copy the current message of @p chan into @p msg.
 *
 * @param chan    Channel to read.
 * @param msg     Destination; must hold chan->message_size bytes.
 * @param timeout How long to wait for the channel semaphore (forced to K_NO_WAIT in ISRs).
 *
 * @return 0 on success, or the k_sem_take() error.
 */
int zbus_chan_read(const struct zbus_channel *chan, void *msg, k_timeout_t timeout)
{
	_ZBUS_ASSERT(chan != NULL, "chan is required");
	_ZBUS_ASSERT(msg != NULL, "msg is required");

	if (k_is_in_isr()) {
		timeout = K_NO_WAIT;
	}

	int ret = k_sem_take(&chan->data->sem, timeout);

	if (ret == 0) {
		memcpy(msg, chan->message, chan->message_size);
		k_sem_give(&chan->data->sem);
	}

	return ret;
}
399 
/**
 * @brief Notify the observers of @p chan without changing its message.
 *
 * @param chan    Channel whose observers are notified.
 * @param timeout Maximum time for locking and notification (K_NO_WAIT in ISRs).
 *
 * @return The chan_lock() error, or the result of notifying the observers.
 */
int zbus_chan_notify(const struct zbus_channel *chan, k_timeout_t timeout)
{
	_ZBUS_ASSERT(chan != NULL, "chan is required");

	if (k_is_in_isr()) {
		timeout = K_NO_WAIT;
	}

	const k_timepoint_t deadline = sys_timepoint_calc(timeout);
	int saved_prio = ZBUS_MIN_THREAD_PRIORITY;
	int ret = chan_lock(chan, timeout, &saved_prio);

	if (ret == 0) {
		ret = _zbus_vded_exec(chan, deadline);
		chan_unlock(chan, saved_prio);
	}

	return ret;
}
425 
/**
 * @brief Take exclusive access to @p chan; release it with zbus_chan_finish().
 *
 * @param chan    Channel to claim.
 * @param timeout How long to wait for the channel semaphore (K_NO_WAIT in ISRs).
 *
 * @return 0 on success, or the k_sem_take() error.
 */
int zbus_chan_claim(const struct zbus_channel *chan, k_timeout_t timeout)
{
	_ZBUS_ASSERT(chan != NULL, "chan is required");

	if (k_is_in_isr()) {
		timeout = K_NO_WAIT;
	}

	return k_sem_take(&chan->data->sem, timeout);
}
442 
zbus_chan_finish(const struct zbus_channel * chan)443 int zbus_chan_finish(const struct zbus_channel *chan)
444 {
445 	_ZBUS_ASSERT(chan != NULL, "chan is required");
446 
447 	k_sem_give(&chan->data->sem);
448 
449 	return 0;
450 }
451 
/**
 * @brief Wait for the next channel notification delivered to subscriber @p sub.
 *
 * @param sub     Subscriber observer (ZBUS_OBSERVER_SUBSCRIBER_TYPE). Not usable in ISRs.
 * @param chan    Out: the channel that was notified.
 * @param timeout How long to wait for a notification.
 *
 * @return The result of k_msgq_get() on the subscriber's queue.
 */
int zbus_sub_wait(const struct zbus_observer *sub, const struct zbus_channel **chan,
		  k_timeout_t timeout)
{
	_ZBUS_ASSERT(!k_is_in_isr(), "zbus_sub_wait cannot be used inside ISRs");
	_ZBUS_ASSERT(sub != NULL, "sub is required");
	_ZBUS_ASSERT(sub->type == ZBUS_OBSERVER_SUBSCRIBER_TYPE, "sub must be a SUBSCRIBER");
	_ZBUS_ASSERT(sub->queue != NULL, "sub queue is required");
	_ZBUS_ASSERT(chan != NULL, "chan is required");

	/* Each queue entry is a single channel pointer. */
	const int ret = k_msgq_get(sub->queue, chan, timeout);

	return ret;
}
463 
#if defined(CONFIG_ZBUS_MSG_SUBSCRIBER)

/**
 * @brief Wait for the next message delivered to message subscriber @p sub.
 *
 * @param sub     Message subscriber observer. Not usable in ISRs.
 * @param chan    Out: the channel the message was published on.
 * @param msg     Out: receives zbus_chan_msg_size(*chan) bytes of message data.
 * @param timeout How long to wait for a message.
 *
 * @retval 0       A message was received.
 * @retval -ENOMSG No buffer was obtained from the subscriber's FIFO.
 */
int zbus_sub_wait_msg(const struct zbus_observer *sub, const struct zbus_channel **chan, void *msg,
		      k_timeout_t timeout)
{
	_ZBUS_ASSERT(!k_is_in_isr(), "zbus_sub_wait_msg cannot be used inside ISRs");
	_ZBUS_ASSERT(sub != NULL, "sub is required");
	_ZBUS_ASSERT(sub->type == ZBUS_OBSERVER_MSG_SUBSCRIBER_TYPE,
		     "sub must be a MSG_SUBSCRIBER");
	_ZBUS_ASSERT(sub->message_fifo != NULL, "sub message_fifo is required");
	_ZBUS_ASSERT(chan != NULL, "chan is required");
	_ZBUS_ASSERT(msg != NULL, "msg is required");

	struct net_buf *buf = net_buf_get(sub->message_fifo, timeout);

	if (buf == NULL) {
		return -ENOMSG;
	}

	/* The notifier stored the source channel pointer in the buffer's user data. */
	memcpy(chan, net_buf_user_data(buf), sizeof(*chan));

	const size_t len = zbus_chan_msg_size(*chan);

	memcpy(msg, net_buf_remove_mem(buf, len), len);
	net_buf_unref(buf);

	return 0;
}

#endif /* CONFIG_ZBUS_MSG_SUBSCRIBER */
493 
zbus_obs_set_chan_notification_mask(const struct zbus_observer * obs,const struct zbus_channel * chan,bool masked)494 int zbus_obs_set_chan_notification_mask(const struct zbus_observer *obs,
495 					const struct zbus_channel *chan, bool masked)
496 {
497 	_ZBUS_ASSERT(obs != NULL, "obs is required");
498 	_ZBUS_ASSERT(chan != NULL, "chan is required");
499 
500 	int err = -ESRCH;
501 
502 	struct zbus_channel_observation *observation;
503 	struct zbus_channel_observation_mask *observation_mask;
504 
505 	K_SPINLOCK(&obs_slock) {
506 		for (int16_t i = chan->data->observers_start_idx,
507 			     limit = chan->data->observers_end_idx;
508 		     i < limit; ++i) {
509 			STRUCT_SECTION_GET(zbus_channel_observation, i, &observation);
510 			STRUCT_SECTION_GET(zbus_channel_observation_mask, i, &observation_mask);
511 
512 			__ASSERT(observation != NULL, "observation must be not NULL");
513 
514 			if (observation->obs == obs) {
515 				if (observation_mask->enabled != masked) {
516 					observation_mask->enabled = masked;
517 
518 					update_all_channels_hop(obs);
519 				}
520 
521 				err = 0;
522 
523 				K_SPINLOCK_BREAK;
524 			}
525 		}
526 	}
527 
528 	return err;
529 }
530 
zbus_obs_is_chan_notification_masked(const struct zbus_observer * obs,const struct zbus_channel * chan,bool * masked)531 int zbus_obs_is_chan_notification_masked(const struct zbus_observer *obs,
532 					 const struct zbus_channel *chan, bool *masked)
533 {
534 	_ZBUS_ASSERT(obs != NULL, "obs is required");
535 	_ZBUS_ASSERT(chan != NULL, "chan is required");
536 
537 	int err = -ESRCH;
538 
539 	struct zbus_channel_observation *observation;
540 	struct zbus_channel_observation_mask *observation_mask;
541 
542 	K_SPINLOCK(&obs_slock) {
543 		const int limit = chan->data->observers_end_idx;
544 
545 		for (int16_t i = chan->data->observers_start_idx; i < limit; ++i) {
546 			STRUCT_SECTION_GET(zbus_channel_observation, i, &observation);
547 			STRUCT_SECTION_GET(zbus_channel_observation_mask, i, &observation_mask);
548 
549 			__ASSERT(observation != NULL, "observation must be not NULL");
550 
551 			if (observation->obs == obs) {
552 				*masked = observation_mask->enabled;
553 
554 				err = 0;
555 
556 				K_SPINLOCK_BREAK;
557 			}
558 		}
559 	}
560 
561 	return err;
562 }
563 
zbus_obs_set_enable(struct zbus_observer * obs,bool enabled)564 int zbus_obs_set_enable(struct zbus_observer *obs, bool enabled)
565 {
566 	_ZBUS_ASSERT(obs != NULL, "obs is required");
567 
568 	K_SPINLOCK(&obs_slock) {
569 		if (obs->data->enabled != enabled) {
570 			obs->data->enabled = enabled;
571 
572 			update_all_channels_hop(obs);
573 		}
574 	}
575 
576 	return 0;
577 }
578