1 /*
2  * Copyright (c) 2018 Nordic Semiconductor ASA
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 /** @file mqtt.c
8  *
9  * @brief MQTT Client API Implementation.
10  */
11 
12 #include <zephyr/logging/log.h>
13 LOG_MODULE_REGISTER(net_mqtt, CONFIG_MQTT_LOG_LEVEL);
14 
15 #include <zephyr/net/mqtt.h>
16 
17 #include "mqtt_transport.h"
18 #include "mqtt_internal.h"
19 #include "mqtt_os.h"
20 
client_reset(struct mqtt_client * client)21 static void client_reset(struct mqtt_client *client)
22 {
23 	MQTT_STATE_INIT(client);
24 
25 	client->internal.last_activity = 0U;
26 	client->internal.rx_buf_datalen = 0U;
27 	client->internal.remaining_payload = 0U;
28 }
29 
30 /** @brief Initialize tx buffer. */
tx_buf_init(struct mqtt_client * client,struct buf_ctx * buf)31 static void tx_buf_init(struct mqtt_client *client, struct buf_ctx *buf)
32 {
33 	memset(client->tx_buf, 0, client->tx_buf_size);
34 	buf->cur = client->tx_buf;
35 	buf->end = client->tx_buf + client->tx_buf_size;
36 }
37 
event_notify(struct mqtt_client * client,const struct mqtt_evt * evt)38 void event_notify(struct mqtt_client *client, const struct mqtt_evt *evt)
39 {
40 	if (client->evt_cb != NULL) {
41 		mqtt_mutex_unlock(client);
42 
43 		client->evt_cb(client, evt);
44 
45 		mqtt_mutex_lock(client);
46 	}
47 }
48 
client_disconnect(struct mqtt_client * client,int result,bool notify)49 static void client_disconnect(struct mqtt_client *client, int result,
50 			      bool notify)
51 {
52 	int err_code;
53 
54 	err_code = mqtt_transport_disconnect(client);
55 	if (err_code < 0) {
56 		NET_ERR("Failed to disconnect transport!");
57 	}
58 
59 	/* Reset internal state. */
60 	client_reset(client);
61 
62 	if (notify) {
63 		struct mqtt_evt evt = {
64 			.type = MQTT_EVT_DISCONNECT,
65 			.result = result,
66 		};
67 
68 		/* Notify application. */
69 		event_notify(client, &evt);
70 	}
71 }
72 
client_connect(struct mqtt_client * client)73 static int client_connect(struct mqtt_client *client)
74 {
75 	int err_code;
76 	struct buf_ctx packet;
77 
78 	err_code = mqtt_transport_connect(client);
79 	if (err_code < 0) {
80 		return err_code;
81 	}
82 
83 	tx_buf_init(client, &packet);
84 	MQTT_SET_STATE(client, MQTT_STATE_TCP_CONNECTED);
85 
86 	err_code = connect_request_encode(client, &packet);
87 	if (err_code < 0) {
88 		goto error;
89 	}
90 
91 	/* Send MQTT identification message to broker. */
92 	err_code = mqtt_transport_write(client, packet.cur,
93 					packet.end - packet.cur);
94 	if (err_code < 0) {
95 		goto error;
96 	}
97 
98 	client->internal.last_activity = mqtt_sys_tick_in_ms_get();
99 
100 	/* Reset the unanswered ping count for a new connection */
101 	client->unacked_ping = 0;
102 
103 	NET_INFO("Connect completed");
104 
105 	return 0;
106 
107 error:
108 	client_disconnect(client, err_code, false);
109 	return err_code;
110 }
111 
client_read(struct mqtt_client * client)112 static int client_read(struct mqtt_client *client)
113 {
114 	int err_code;
115 
116 	if (client->internal.remaining_payload > 0) {
117 		return -EBUSY;
118 	}
119 
120 	err_code = mqtt_handle_rx(client);
121 	if (err_code < 0) {
122 		client_disconnect(client, err_code, true);
123 	}
124 
125 	return err_code;
126 }
127 
client_write(struct mqtt_client * client,const uint8_t * data,uint32_t datalen)128 static int client_write(struct mqtt_client *client, const uint8_t *data,
129 			uint32_t datalen)
130 {
131 	int err_code;
132 
133 	NET_DBG("[%p]: Transport writing %d bytes.", client, datalen);
134 
135 	err_code = mqtt_transport_write(client, data, datalen);
136 	if (err_code < 0) {
137 		NET_ERR("Transport write failed, err_code = %d, "
138 			 "closing connection", err_code);
139 		client_disconnect(client, err_code, true);
140 		return err_code;
141 	}
142 
143 	NET_DBG("[%p]: Transport write complete.", client);
144 	client->internal.last_activity = mqtt_sys_tick_in_ms_get();
145 
146 	return 0;
147 }
148 
client_write_msg(struct mqtt_client * client,const struct msghdr * message)149 static int client_write_msg(struct mqtt_client *client,
150 			    const struct msghdr *message)
151 {
152 	int err_code;
153 
154 	NET_DBG("[%p]: Transport writing message.", client);
155 
156 	err_code = mqtt_transport_write_msg(client, message);
157 	if (err_code < 0) {
158 		NET_ERR("Transport write failed, err_code = %d, "
159 			 "closing connection", err_code);
160 		client_disconnect(client, err_code, true);
161 		return err_code;
162 	}
163 
164 	NET_DBG("[%p]: Transport write complete.", client);
165 	client->internal.last_activity = mqtt_sys_tick_in_ms_get();
166 
167 	return 0;
168 }
169 
mqtt_client_init(struct mqtt_client * client)170 void mqtt_client_init(struct mqtt_client *client)
171 {
172 	NULL_PARAM_CHECK_VOID(client);
173 
174 	memset(client, 0, sizeof(*client));
175 
176 	MQTT_STATE_INIT(client);
177 	mqtt_mutex_init(client);
178 
179 	client->protocol_version = MQTT_VERSION_3_1_1;
180 	client->clean_session = MQTT_CLEAN_SESSION;
181 	client->keepalive = MQTT_KEEPALIVE;
182 }
183 
#if defined(CONFIG_SOCKS)
/**
 * @brief Store a SOCKS proxy address in the client's transport configuration.
 *
 * Copies @p addrlen bytes of @p proxy_addr into the client.
 *
 * @param client     Client instance to configure.
 * @param proxy_addr Proxy address to copy.
 * @param addrlen    Length of @p proxy_addr in bytes.
 *
 * @return 0 on success, or -EINVAL if @p client or @p proxy_addr is NULL or
 *         @p addrlen exceeds the client's proxy address storage.
 */
int mqtt_client_set_proxy(struct mqtt_client *client,
			  struct sockaddr *proxy_addr,
			  socklen_t addrlen)
{
	/* This function is compiled only when CONFIG_SOCKS is defined, so no
	 * run-time IS_ENABLED() check is needed here.
	 */
	if (!client || !proxy_addr) {
		return -EINVAL;
	}

	/* addrlen is caller-controlled; reject anything larger than the
	 * destination so memcpy() cannot overrun the client structure.
	 */
	if (addrlen > sizeof(client->transport.proxy.addr)) {
		return -EINVAL;
	}

	client->transport.proxy.addrlen = addrlen;
	memcpy(&client->transport.proxy.addr, proxy_addr, addrlen);

	return 0;
}
#endif
203 
mqtt_connect(struct mqtt_client * client)204 int mqtt_connect(struct mqtt_client *client)
205 {
206 	int err_code;
207 
208 	NULL_PARAM_CHECK(client);
209 	NULL_PARAM_CHECK(client->client_id.utf8);
210 
211 	mqtt_mutex_lock(client);
212 
213 	if ((client->tx_buf == NULL) || (client->rx_buf == NULL)) {
214 		err_code = -ENOMEM;
215 		goto error;
216 	}
217 
218 	err_code = client_connect(client);
219 
220 error:
221 	if (err_code < 0) {
222 		client_reset(client);
223 	}
224 
225 	mqtt_mutex_unlock(client);
226 
227 	return err_code;
228 }
229 
verify_tx_state(const struct mqtt_client * client)230 static int verify_tx_state(const struct mqtt_client *client)
231 {
232 	if (!MQTT_HAS_STATE(client, MQTT_STATE_CONNECTED)) {
233 		return -ENOTCONN;
234 	}
235 
236 	return 0;
237 }
238 
mqtt_publish(struct mqtt_client * client,const struct mqtt_publish_param * param)239 int mqtt_publish(struct mqtt_client *client,
240 		 const struct mqtt_publish_param *param)
241 {
242 	int err_code;
243 	struct buf_ctx packet;
244 	struct iovec io_vector[2];
245 	struct msghdr msg;
246 
247 	NULL_PARAM_CHECK(client);
248 	NULL_PARAM_CHECK(param);
249 
250 	NET_DBG("[CID %p]:[State 0x%02x]: >> Topic size 0x%08x, "
251 		 "Data size 0x%08x", client, client->internal.state,
252 		 param->message.topic.topic.size,
253 		 param->message.payload.len);
254 
255 	mqtt_mutex_lock(client);
256 
257 	tx_buf_init(client, &packet);
258 
259 	err_code = verify_tx_state(client);
260 	if (err_code < 0) {
261 		goto error;
262 	}
263 
264 	err_code = publish_encode(param, &packet);
265 	if (err_code < 0) {
266 		goto error;
267 	}
268 
269 	io_vector[0].iov_base = packet.cur;
270 	io_vector[0].iov_len = packet.end - packet.cur;
271 	io_vector[1].iov_base = param->message.payload.data;
272 	io_vector[1].iov_len = param->message.payload.len;
273 
274 	memset(&msg, 0, sizeof(msg));
275 
276 	msg.msg_iov = io_vector;
277 	msg.msg_iovlen = ARRAY_SIZE(io_vector);
278 
279 	err_code = client_write_msg(client, &msg);
280 
281 error:
282 	NET_DBG("[CID %p]:[State 0x%02x]: << result 0x%08x",
283 			 client, client->internal.state, err_code);
284 
285 	mqtt_mutex_unlock(client);
286 
287 	return err_code;
288 }
289 
mqtt_publish_qos1_ack(struct mqtt_client * client,const struct mqtt_puback_param * param)290 int mqtt_publish_qos1_ack(struct mqtt_client *client,
291 			  const struct mqtt_puback_param *param)
292 {
293 	int err_code;
294 	struct buf_ctx packet;
295 
296 	NULL_PARAM_CHECK(client);
297 	NULL_PARAM_CHECK(param);
298 
299 	NET_DBG("[CID %p]:[State 0x%02x]: >> Message id 0x%04x",
300 		 client, client->internal.state, param->message_id);
301 
302 	mqtt_mutex_lock(client);
303 
304 	tx_buf_init(client, &packet);
305 
306 	err_code = verify_tx_state(client);
307 	if (err_code < 0) {
308 		goto error;
309 	}
310 
311 	err_code = publish_ack_encode(param, &packet);
312 	if (err_code < 0) {
313 		goto error;
314 	}
315 
316 	err_code = client_write(client, packet.cur, packet.end - packet.cur);
317 
318 error:
319 	NET_DBG("[CID %p]:[State 0x%02x]: << result 0x%08x",
320 		 client, client->internal.state, err_code);
321 
322 	mqtt_mutex_unlock(client);
323 
324 	return err_code;
325 }
326 
mqtt_publish_qos2_receive(struct mqtt_client * client,const struct mqtt_pubrec_param * param)327 int mqtt_publish_qos2_receive(struct mqtt_client *client,
328 			      const struct mqtt_pubrec_param *param)
329 {
330 	int err_code;
331 	struct buf_ctx packet;
332 
333 	NULL_PARAM_CHECK(client);
334 	NULL_PARAM_CHECK(param);
335 
336 	NET_DBG("[CID %p]:[State 0x%02x]: >> Message id 0x%04x",
337 		 client, client->internal.state, param->message_id);
338 
339 	mqtt_mutex_lock(client);
340 
341 	tx_buf_init(client, &packet);
342 
343 	err_code = verify_tx_state(client);
344 	if (err_code < 0) {
345 		goto error;
346 	}
347 
348 	err_code = publish_receive_encode(param, &packet);
349 	if (err_code < 0) {
350 		goto error;
351 	}
352 
353 	err_code = client_write(client, packet.cur, packet.end - packet.cur);
354 
355 error:
356 	NET_DBG("[CID %p]:[State 0x%02x]: << result 0x%08x",
357 		 client, client->internal.state, err_code);
358 
359 	mqtt_mutex_unlock(client);
360 
361 	return err_code;
362 }
363 
mqtt_publish_qos2_release(struct mqtt_client * client,const struct mqtt_pubrel_param * param)364 int mqtt_publish_qos2_release(struct mqtt_client *client,
365 			      const struct mqtt_pubrel_param *param)
366 {
367 	int err_code;
368 	struct buf_ctx packet;
369 
370 	NULL_PARAM_CHECK(client);
371 	NULL_PARAM_CHECK(param);
372 
373 	NET_DBG("[CID %p]:[State 0x%02x]: >> Message id 0x%04x",
374 		 client, client->internal.state, param->message_id);
375 
376 	mqtt_mutex_lock(client);
377 
378 	tx_buf_init(client, &packet);
379 
380 	err_code = verify_tx_state(client);
381 	if (err_code < 0) {
382 		goto error;
383 	}
384 
385 	err_code = publish_release_encode(param, &packet);
386 	if (err_code < 0) {
387 		goto error;
388 	}
389 
390 	err_code = client_write(client, packet.cur, packet.end - packet.cur);
391 
392 error:
393 	NET_DBG("[CID %p]:[State 0x%02x]: << result 0x%08x",
394 		 client, client->internal.state, err_code);
395 
396 	mqtt_mutex_unlock(client);
397 
398 	return err_code;
399 }
400 
mqtt_publish_qos2_complete(struct mqtt_client * client,const struct mqtt_pubcomp_param * param)401 int mqtt_publish_qos2_complete(struct mqtt_client *client,
402 			       const struct mqtt_pubcomp_param *param)
403 {
404 	int err_code;
405 	struct buf_ctx packet;
406 
407 	NULL_PARAM_CHECK(client);
408 	NULL_PARAM_CHECK(param);
409 
410 	NET_DBG("[CID %p]:[State 0x%02x]: >> Message id 0x%04x",
411 		 client, client->internal.state, param->message_id);
412 
413 	mqtt_mutex_lock(client);
414 
415 	tx_buf_init(client, &packet);
416 
417 	err_code = verify_tx_state(client);
418 	if (err_code < 0) {
419 		goto error;
420 	}
421 
422 	err_code = publish_complete_encode(param, &packet);
423 	if (err_code < 0) {
424 		goto error;
425 	}
426 
427 	err_code = client_write(client, packet.cur, packet.end - packet.cur);
428 	if (err_code < 0) {
429 		goto error;
430 	}
431 
432 error:
433 	NET_DBG("[CID %p]:[State 0x%02x]: << result 0x%08x",
434 		 client, client->internal.state, err_code);
435 
436 	mqtt_mutex_unlock(client);
437 
438 	return err_code;
439 }
440 
mqtt_disconnect(struct mqtt_client * client)441 int mqtt_disconnect(struct mqtt_client *client)
442 {
443 	int err_code;
444 	struct buf_ctx packet;
445 
446 	NULL_PARAM_CHECK(client);
447 
448 	mqtt_mutex_lock(client);
449 
450 	tx_buf_init(client, &packet);
451 
452 	err_code = verify_tx_state(client);
453 	if (err_code < 0) {
454 		goto error;
455 	}
456 
457 	err_code = disconnect_encode(&packet);
458 	if (err_code < 0) {
459 		goto error;
460 	}
461 
462 	err_code = client_write(client, packet.cur, packet.end - packet.cur);
463 	if (err_code < 0) {
464 		goto error;
465 	}
466 
467 	client_disconnect(client, 0, true);
468 
469 error:
470 	mqtt_mutex_unlock(client);
471 
472 	return err_code;
473 }
474 
mqtt_subscribe(struct mqtt_client * client,const struct mqtt_subscription_list * param)475 int mqtt_subscribe(struct mqtt_client *client,
476 		   const struct mqtt_subscription_list *param)
477 {
478 	int err_code;
479 	struct buf_ctx packet;
480 
481 	NULL_PARAM_CHECK(client);
482 	NULL_PARAM_CHECK(param);
483 
484 	NET_DBG("[CID %p]:[State 0x%02x]: >> message id 0x%04x "
485 		 "topic count 0x%04x", client, client->internal.state,
486 		 param->message_id, param->list_count);
487 
488 	mqtt_mutex_lock(client);
489 
490 	tx_buf_init(client, &packet);
491 
492 	err_code = verify_tx_state(client);
493 	if (err_code < 0) {
494 		goto error;
495 	}
496 
497 	err_code = subscribe_encode(param, &packet);
498 	if (err_code < 0) {
499 		goto error;
500 	}
501 
502 	err_code = client_write(client, packet.cur, packet.end - packet.cur);
503 
504 error:
505 	NET_DBG("[CID %p]:[State 0x%02x]: << result 0x%08x",
506 		 client, client->internal.state, err_code);
507 
508 	mqtt_mutex_unlock(client);
509 
510 	return err_code;
511 }
512 
mqtt_unsubscribe(struct mqtt_client * client,const struct mqtt_subscription_list * param)513 int mqtt_unsubscribe(struct mqtt_client *client,
514 		     const struct mqtt_subscription_list *param)
515 {
516 	int err_code;
517 	struct buf_ctx packet;
518 
519 	NULL_PARAM_CHECK(client);
520 	NULL_PARAM_CHECK(param);
521 
522 	mqtt_mutex_lock(client);
523 
524 	tx_buf_init(client, &packet);
525 
526 	err_code = verify_tx_state(client);
527 	if (err_code < 0) {
528 		goto error;
529 	}
530 
531 	err_code = unsubscribe_encode(param, &packet);
532 	if (err_code < 0) {
533 		goto error;
534 	}
535 
536 	err_code = client_write(client, packet.cur, packet.end - packet.cur);
537 
538 error:
539 	mqtt_mutex_unlock(client);
540 
541 	return err_code;
542 }
543 
mqtt_ping(struct mqtt_client * client)544 int mqtt_ping(struct mqtt_client *client)
545 {
546 	int err_code;
547 	struct buf_ctx packet;
548 
549 	NULL_PARAM_CHECK(client);
550 
551 	mqtt_mutex_lock(client);
552 
553 	tx_buf_init(client, &packet);
554 
555 	err_code = verify_tx_state(client);
556 	if (err_code < 0) {
557 		goto error;
558 	}
559 
560 	err_code = ping_request_encode(&packet);
561 	if (err_code < 0) {
562 		goto error;
563 	}
564 
565 	err_code = client_write(client, packet.cur, packet.end - packet.cur);
566 
567 	if (client->unacked_ping >= INT8_MAX) {
568 		NET_WARN("PING count overflow!");
569 	} else {
570 		client->unacked_ping++;
571 	}
572 
573 error:
574 	mqtt_mutex_unlock(client);
575 
576 	return err_code;
577 }
578 
mqtt_abort(struct mqtt_client * client)579 int mqtt_abort(struct mqtt_client *client)
580 {
581 	NULL_PARAM_CHECK(client);
582 
583 	mqtt_mutex_lock(client);
584 
585 	if (client->internal.state != MQTT_STATE_IDLE) {
586 		client_disconnect(client, -ECONNABORTED, true);
587 	}
588 
589 	mqtt_mutex_unlock(client);
590 
591 	return 0;
592 }
593 
mqtt_live(struct mqtt_client * client)594 int mqtt_live(struct mqtt_client *client)
595 {
596 	int err_code = 0;
597 	uint32_t elapsed_time;
598 	bool ping_sent = false;
599 
600 	NULL_PARAM_CHECK(client);
601 
602 	mqtt_mutex_lock(client);
603 
604 	elapsed_time = mqtt_elapsed_time_in_ms_get(
605 				client->internal.last_activity);
606 	if ((client->keepalive > 0) &&
607 	    (elapsed_time >= (client->keepalive * 1000))) {
608 		err_code = mqtt_ping(client);
609 		ping_sent = true;
610 	}
611 
612 	mqtt_mutex_unlock(client);
613 
614 	if (ping_sent) {
615 		return err_code;
616 	} else {
617 		return -EAGAIN;
618 	}
619 }
620 
mqtt_keepalive_time_left(const struct mqtt_client * client)621 int mqtt_keepalive_time_left(const struct mqtt_client *client)
622 {
623 	uint32_t elapsed_time = mqtt_elapsed_time_in_ms_get(
624 					client->internal.last_activity);
625 	uint32_t keepalive_ms = 1000U * client->keepalive;
626 
627 	if (client->keepalive == 0) {
628 		/* Keep alive not enabled. */
629 		return -1;
630 	}
631 
632 	if (keepalive_ms <= elapsed_time) {
633 		return 0;
634 	}
635 
636 	return keepalive_ms - elapsed_time;
637 }
638 
mqtt_input(struct mqtt_client * client)639 int mqtt_input(struct mqtt_client *client)
640 {
641 	int err_code = 0;
642 
643 	NULL_PARAM_CHECK(client);
644 
645 	mqtt_mutex_lock(client);
646 
647 	NET_DBG("state:0x%08x", client->internal.state);
648 
649 	if (MQTT_HAS_STATE(client, MQTT_STATE_TCP_CONNECTED)) {
650 		err_code = client_read(client);
651 	} else {
652 		err_code = -ENOTCONN;
653 	}
654 
655 	mqtt_mutex_unlock(client);
656 
657 	return err_code;
658 }
659 
read_publish_payload(struct mqtt_client * client,void * buffer,size_t length,bool shall_block)660 static int read_publish_payload(struct mqtt_client *client, void *buffer,
661 				size_t length, bool shall_block)
662 {
663 	int ret;
664 
665 	NULL_PARAM_CHECK(client);
666 
667 	mqtt_mutex_lock(client);
668 
669 	if (client->internal.remaining_payload == 0U) {
670 		ret = 0;
671 		goto exit;
672 	}
673 
674 	if (client->internal.remaining_payload < length) {
675 		length = client->internal.remaining_payload;
676 	}
677 
678 	ret = mqtt_transport_read(client, buffer, length, shall_block);
679 	if (!shall_block && ret == -EAGAIN) {
680 		goto exit;
681 	}
682 
683 	if (ret <= 0) {
684 		if (ret == 0) {
685 			ret = -ENOTCONN;
686 		}
687 
688 		client_disconnect(client, ret, true);
689 		goto exit;
690 	}
691 
692 	client->internal.remaining_payload -= ret;
693 
694 exit:
695 	mqtt_mutex_unlock(client);
696 
697 	return ret;
698 }
699 
/**
 * @brief Non-blocking read of the current PUBLISH payload.
 *
 * @return Bytes read, 0 when no payload remains, or a negative error code
 *         (including -EAGAIN when no data is available yet).
 */
int mqtt_read_publish_payload(struct mqtt_client *client, void *buffer,
			      size_t length)
{
	const bool blocking = false;

	return read_publish_payload(client, buffer, length, blocking);
}
705 
/**
 * @brief Blocking read of the current PUBLISH payload.
 *
 * @return Bytes read, 0 when no payload remains, or a negative error code.
 */
int mqtt_read_publish_payload_blocking(struct mqtt_client *client, void *buffer,
				       size_t length)
{
	const bool blocking = true;

	return read_publish_payload(client, buffer, length, blocking);
}
711 
/**
 * @brief Read exactly @p length bytes of PUBLISH payload, blocking as needed.
 *
 * @param client Client instance reading the payload.
 * @param buffer Destination buffer of at least @p length bytes.
 * @param length Number of bytes to read.
 *
 * @return 0 once @p length bytes have been read, the negative error from a
 *         failed read, or -EIO if the payload ends before @p length bytes.
 */
int mqtt_readall_publish_payload(struct mqtt_client *client, uint8_t *buffer,
				 size_t length)
{
	size_t done = 0;

	while (done < length) {
		int rc = mqtt_read_publish_payload_blocking(client,
							    buffer + done,
							    length - done);

		if (rc < 0) {
			return rc;
		}

		if (rc == 0) {
			/* Payload exhausted before the request was met. */
			return -EIO;
		}

		done += (size_t)rc;
	}

	return 0;
}
732