1 /*
2  * Copyright (c) 2024 Analog Devices, Inc.
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
#include <zephyr/logging/log.h>
LOG_MODULE_REGISTER(app_mqtt, LOG_LEVEL_DBG);

#include <errno.h>
#include <string.h>

#include <zephyr/kernel.h>
#include <zephyr/net/socket.h>
#include <zephyr/net/mqtt.h>
#include <zephyr/data/json.h>
#include <zephyr/random/random.h>

#include "mqtt_client.h"
#include "device.h"
18 
19 /* Buffers for MQTT client */
20 static uint8_t rx_buffer[CONFIG_NET_SAMPLE_MQTT_PAYLOAD_SIZE];
21 static uint8_t tx_buffer[CONFIG_NET_SAMPLE_MQTT_PAYLOAD_SIZE];
22 
23 /* MQTT payload buffer */
24 static uint8_t payload_buf[CONFIG_NET_SAMPLE_MQTT_PAYLOAD_SIZE];
25 
26 /* MQTT broker details */
27 static struct sockaddr_storage broker;
28 
29 /* Socket descriptor */
30 static struct zsock_pollfd fds[1];
31 static int nfds;
32 
33 /* JSON payload format */
34 static const struct json_obj_descr sensor_sample_descr[] = {
35 	JSON_OBJ_DESCR_PRIM(struct sensor_sample, unit, JSON_TOK_STRING),
36 	JSON_OBJ_DESCR_PRIM(struct sensor_sample, value, JSON_TOK_NUMBER),
37 };
38 
39 /* MQTT connectivity status flag */
40 bool mqtt_connected;
41 
42 /* MQTT client ID buffer */
43 static uint8_t client_id[50];
44 
#if defined(CONFIG_MQTT_LIB_TLS)
#include "tls_config/cert.h"

/* Hostname sent via SNI and checked against the CN/subjectAltName of the
 * broker's certificate (not the CA certificate)
 */
#define TLS_SNI_HOSTNAME	CONFIG_NET_SAMPLE_MQTT_BROKER_HOSTNAME
#define APP_CA_CERT_TAG		1

static const sec_tag_t m_sec_tags[] = {
	APP_CA_CERT_TAG,
};

/**
 * Register the broker's CA certificate with the TLS credential store.
 *
 * The credential store outlives this call. If app_mqtt_init() is ever run
 * more than once, tls_credential_add() finds the tag already populated and
 * returns -EEXIST. That is treated as success rather than a fatal error.
 *
 * @return 0 on success, or the negative error from tls_credential_add().
 */
static int tls_init(void)
{
	int rc;

	rc = tls_credential_add(APP_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE,
				ca_certificate, sizeof(ca_certificate));
	if (rc == -EEXIST) {
		/* Already registered by an earlier initialisation */
		return 0;
	}
	if (rc < 0) {
		LOG_ERR("Failed to register public certificate: %d", rc);
		return rc;
	}

	return 0;
}
#endif
71 
prepare_fds(struct mqtt_client * client)72 static void prepare_fds(struct mqtt_client *client)
73 {
74 	if (client->transport.type == MQTT_TRANSPORT_NON_SECURE) {
75 		fds[0].fd = client->transport.tcp.sock;
76 	}
77 #if defined(CONFIG_MQTT_LIB_TLS)
78 	else if (client->transport.type == MQTT_TRANSPORT_SECURE) {
79 		fds[0].fd = client->transport.tls.sock;
80 	}
81 #endif
82 
83 	fds[0].events = ZSOCK_POLLIN;
84 	nfds = 1;
85 }
86 
clear_fds(void)87 static void clear_fds(void)
88 {
89 	nfds = 0;
90 }
91 
92 /** Initialise the MQTT client ID as the board name with random hex postfix */
init_mqtt_client_id(void)93 static void init_mqtt_client_id(void)
94 {
95 	snprintk(client_id, sizeof(client_id), CONFIG_BOARD"_%x", (uint8_t)sys_rand32_get());
96 }
97 
on_mqtt_connect(void)98 static inline void on_mqtt_connect(void)
99 {
100 	mqtt_connected = true;
101 	device_write_led(LED_NET, LED_ON);
102 	LOG_INF("Connected to MQTT broker!");
103 	LOG_INF("Hostname: %s", CONFIG_NET_SAMPLE_MQTT_BROKER_HOSTNAME);
104 	LOG_INF("Client ID: %s", client_id);
105 	LOG_INF("Port: %s", CONFIG_NET_SAMPLE_MQTT_BROKER_PORT);
106 	LOG_INF("TLS: %s",
107 		IS_ENABLED(CONFIG_MQTT_LIB_TLS) ? "Enabled" : "Disabled");
108 }
109 
on_mqtt_disconnect(void)110 static inline void on_mqtt_disconnect(void)
111 {
112 	mqtt_connected = false;
113 	clear_fds();
114 	device_write_led(LED_NET, LED_OFF);
115 	LOG_INF("Disconnected from MQTT broker");
116 }
117 
118 /** Called when an MQTT payload is received.
119  *  Reads the payload and calls the commands
120  *  handler if a payloads is received on the
121  *  command topic
122  */
on_mqtt_publish(struct mqtt_client * const client,const struct mqtt_evt * evt)123 static void on_mqtt_publish(struct mqtt_client *const client, const struct mqtt_evt *evt)
124 {
125 	int rc;
126 	uint8_t payload[CONFIG_NET_SAMPLE_MQTT_PAYLOAD_SIZE];
127 
128 	rc = mqtt_read_publish_payload(client, payload,
129 					CONFIG_NET_SAMPLE_MQTT_PAYLOAD_SIZE);
130 	if (rc < 0) {
131 		LOG_ERR("Failed to read received MQTT payload [%d]", rc);
132 		return;
133 	}
134 	/* Place null terminator at end of payload buffer */
135 	payload[rc] = '\0';
136 
137 	LOG_INF("MQTT payload received!");
138 	LOG_INF("topic: '%s', payload: %s",
139 		evt->param.publish.message.topic.topic.utf8, payload);
140 
141 	/* If the topic is a command, call the command handler  */
142 	if (strcmp(evt->param.publish.message.topic.topic.utf8,
143 			CONFIG_NET_SAMPLE_MQTT_SUB_TOPIC_CMD) == 0) {
144 		device_command_handler(payload);
145 	}
146 }
147 
148 /** Handler for asynchronous MQTT events */
mqtt_event_handler(struct mqtt_client * const client,const struct mqtt_evt * evt)149 static void mqtt_event_handler(struct mqtt_client *const client, const struct mqtt_evt *evt)
150 {
151 	switch (evt->type) {
152 	case MQTT_EVT_CONNACK:
153 		if (evt->result != 0) {
154 			LOG_ERR("MQTT Event Connect failed [%d]", evt->result);
155 			break;
156 		}
157 		on_mqtt_connect();
158 		break;
159 
160 	case MQTT_EVT_DISCONNECT:
161 		on_mqtt_disconnect();
162 		break;
163 
164 	case MQTT_EVT_PINGRESP:
165 		LOG_INF("PINGRESP packet");
166 		break;
167 
168 	case MQTT_EVT_PUBACK:
169 		if (evt->result != 0) {
170 			LOG_ERR("MQTT PUBACK error [%d]", evt->result);
171 			break;
172 		}
173 
174 		LOG_INF("PUBACK packet ID: %u", evt->param.puback.message_id);
175 		break;
176 
177 	case MQTT_EVT_PUBREC:
178 		if (evt->result != 0) {
179 			LOG_ERR("MQTT PUBREC error [%d]", evt->result);
180 			break;
181 		}
182 
183 		LOG_INF("PUBREC packet ID: %u", evt->param.pubrec.message_id);
184 
185 		const struct mqtt_pubrel_param rel_param = {
186 			.message_id = evt->param.pubrec.message_id
187 		};
188 
189 		mqtt_publish_qos2_release(client, &rel_param);
190 		break;
191 
192 	case MQTT_EVT_PUBREL:
193 		if (evt->result != 0) {
194 			LOG_ERR("MQTT PUBREL error [%d]", evt->result);
195 			break;
196 		}
197 
198 		LOG_INF("PUBREL packet ID: %u", evt->param.pubrel.message_id);
199 
200 		const struct mqtt_pubcomp_param rec_param = {
201 			.message_id = evt->param.pubrel.message_id
202 		};
203 
204 		mqtt_publish_qos2_complete(client, &rec_param);
205 		break;
206 
207 	case MQTT_EVT_PUBCOMP:
208 		if (evt->result != 0) {
209 			LOG_ERR("MQTT PUBCOMP error %d", evt->result);
210 			break;
211 		}
212 
213 		LOG_INF("PUBCOMP packet ID: %u", evt->param.pubcomp.message_id);
214 		break;
215 
216 	case MQTT_EVT_SUBACK:
217 		if (evt->result == 0x80) {
218 			LOG_ERR("MQTT SUBACK error [%d]", evt->result);
219 			break;
220 		}
221 
222 		LOG_INF("SUBACK packet ID: %d", evt->param.suback.message_id);
223 		break;
224 
225 	case MQTT_EVT_PUBLISH:
226 		const struct mqtt_publish_param *p = &evt->param.publish;
227 
228 		if (p->message.topic.qos == MQTT_QOS_1_AT_LEAST_ONCE) {
229 			const struct mqtt_puback_param ack_param = {
230 				.message_id = p->message_id
231 			};
232 			mqtt_publish_qos1_ack(client, &ack_param);
233 		} else if (p->message.topic.qos == MQTT_QOS_2_EXACTLY_ONCE) {
234 			const struct mqtt_pubrec_param rec_param = {
235 				.message_id = p->message_id
236 			};
237 			mqtt_publish_qos2_receive(client, &rec_param);
238 		}
239 
240 		on_mqtt_publish(client, evt);
241 
242 	default:
243 		break;
244 	}
245 }
246 
247 /** Poll the MQTT socket for received data */
poll_mqtt_socket(struct mqtt_client * client,int timeout)248 static int poll_mqtt_socket(struct mqtt_client *client, int timeout)
249 {
250 	int rc;
251 
252 	prepare_fds(client);
253 
254 	if (nfds <= 0) {
255 		return -EINVAL;
256 	}
257 
258 	rc = zsock_poll(fds, nfds, timeout);
259 	if (rc < 0) {
260 		LOG_ERR("Socket poll error [%d]", rc);
261 	}
262 
263 	return rc;
264 }
265 
266 /** Retrieves a sensor sample and encodes it in JSON format */
get_mqtt_payload(struct mqtt_binstr * payload)267 static int get_mqtt_payload(struct mqtt_binstr *payload)
268 {
269 	int rc;
270 	struct sensor_sample sample;
271 
272 	rc = device_read_sensor(&sample);
273 	if (rc != 0) {
274 		LOG_ERR("Failed to get sensor sample [%d]", rc);
275 		return rc;
276 	}
277 
278 	rc = json_obj_encode_buf(sensor_sample_descr, ARRAY_SIZE(sensor_sample_descr),
279 					&sample, payload_buf, CONFIG_NET_SAMPLE_MQTT_PAYLOAD_SIZE);
280 	if (rc != 0) {
281 		LOG_ERR("Failed to encode JSON object [%d]", rc);
282 		return rc;
283 	}
284 
285 	payload->data = payload_buf;
286 	payload->len = strlen(payload->data);
287 
288 	return rc;
289 }
290 
app_mqtt_publish(struct mqtt_client * client)291 int app_mqtt_publish(struct mqtt_client *client)
292 {
293 	int rc;
294 	struct mqtt_publish_param param;
295 	struct mqtt_binstr payload;
296 	static uint16_t msg_id = 1;
297 	struct mqtt_topic topic = {
298 		.topic = {
299 			.utf8 = CONFIG_NET_SAMPLE_MQTT_PUB_TOPIC,
300 			.size = strlen(topic.topic.utf8)
301 		},
302 		.qos = IS_ENABLED(CONFIG_NET_SAMPLE_MQTT_QOS_0_AT_MOST_ONCE) ? 0 :
303 			(IS_ENABLED(CONFIG_NET_SAMPLE_MQTT_QOS_1_AT_LEAST_ONCE) ? 1 : 2)
304 	};
305 
306 	rc = get_mqtt_payload(&payload);
307 	if (rc != 0) {
308 		LOG_ERR("Failed to get MQTT payload [%d]", rc);
309 	}
310 
311 	param.message.topic = topic;
312 	param.message.payload = payload;
313 	param.message_id = msg_id++;
314 	param.dup_flag = 0;
315 	param.retain_flag = 0;
316 
317 	rc = mqtt_publish(client, &param);
318 	if (rc != 0) {
319 		LOG_ERR("MQTT Publish failed [%d]", rc);
320 	}
321 
322 	LOG_INF("Published to topic '%s', QoS %d",
323 			param.message.topic.topic.utf8,
324 			param.message.topic.qos);
325 
326 	return rc;
327 }
328 
app_mqtt_subscribe(struct mqtt_client * client)329 int app_mqtt_subscribe(struct mqtt_client *client)
330 {
331 	int rc;
332 	struct mqtt_topic sub_topics[] = {
333 		{
334 			.topic = {
335 				.utf8 = CONFIG_NET_SAMPLE_MQTT_SUB_TOPIC_CMD,
336 				.size = strlen(sub_topics->topic.utf8)
337 			},
338 			.qos = IS_ENABLED(CONFIG_NET_SAMPLE_MQTT_QOS_0_AT_MOST_ONCE) ? 0 :
339 				(IS_ENABLED(CONFIG_NET_SAMPLE_MQTT_QOS_1_AT_LEAST_ONCE) ? 1 : 2)
340 		}
341 	};
342 	const struct mqtt_subscription_list sub_list = {
343 		.list = sub_topics,
344 		.list_count = ARRAY_SIZE(sub_topics),
345 		.message_id = 5841u
346 	};
347 
348 	LOG_INF("Subscribing to %d topic(s)", sub_list.list_count);
349 
350 	rc = mqtt_subscribe(client, &sub_list);
351 	if (rc != 0) {
352 		LOG_ERR("MQTT Subscribe failed [%d]", rc);
353 	}
354 
355 	return rc;
356 }
357 
358 /** Process incoming MQTT data and keep the connection alive*/
app_mqtt_process(struct mqtt_client * client)359 int app_mqtt_process(struct mqtt_client *client)
360 {
361 	int rc;
362 
363 	rc = poll_mqtt_socket(client, mqtt_keepalive_time_left(client));
364 	if (rc != 0) {
365 		if (fds[0].revents & ZSOCK_POLLIN) {
366 			/* MQTT data received */
367 			rc = mqtt_input(client);
368 			if (rc != 0) {
369 				LOG_ERR("MQTT Input failed [%d]", rc);
370 				return rc;
371 			}
372 			/* Socket error */
373 			if (fds[0].revents & (ZSOCK_POLLHUP | ZSOCK_POLLERR)) {
374 				LOG_ERR("MQTT socket closed / error");
375 				return -ENOTCONN;
376 			}
377 		}
378 	} else {
379 		/* Socket poll timed out, time to call mqtt_live() */
380 		rc = mqtt_live(client);
381 		if (rc != 0) {
382 			LOG_ERR("MQTT Live failed [%d]", rc);
383 			return rc;
384 		}
385 	}
386 
387 	return 0;
388 }
389 
app_mqtt_run(struct mqtt_client * client)390 void app_mqtt_run(struct mqtt_client *client)
391 {
392 	int rc;
393 
394 	/* Subscribe to MQTT topics */
395 	app_mqtt_subscribe(client);
396 
397 	/* Thread will primarily remain in this loop */
398 	while (mqtt_connected) {
399 		rc = app_mqtt_process(client);
400 		if (rc != 0) {
401 			break;
402 		}
403 	}
404 	/* Gracefully close connection */
405 	mqtt_disconnect(client);
406 }
407 
app_mqtt_connect(struct mqtt_client * client)408 void app_mqtt_connect(struct mqtt_client *client)
409 {
410 	int rc = 0;
411 
412 	mqtt_connected = false;
413 
414 	/* Block until MQTT CONNACK event callback occurs */
415 	while (!mqtt_connected) {
416 		rc = mqtt_connect(client);
417 		if (rc != 0) {
418 			LOG_ERR("MQTT Connect failed [%d]", rc);
419 			k_msleep(MSECS_WAIT_RECONNECT);
420 			continue;
421 		}
422 
423 		/* Poll MQTT socket for response */
424 		rc = poll_mqtt_socket(client, MSECS_NET_POLL_TIMEOUT);
425 		if (rc > 0) {
426 			mqtt_input(client);
427 		}
428 
429 		if (!mqtt_connected) {
430 			mqtt_abort(client);
431 		}
432 	}
433 }
434 
/**
 * Resolve the broker address and initialise the MQTT client structure.
 *
 * Only IPv4 addresses are requested from getaddrinfo(); the first result is
 * copied into the file-scope broker address. The client gets the generated
 * client ID, MQTT 3.1.1, no credentials, and the static RX/TX buffers. When
 * CONFIG_MQTT_LIB_TLS is enabled, the CA certificate is registered and the
 * transport is configured for TLS with mandatory peer verification.
 *
 * @param client MQTT client structure to initialise.
 *
 * @return 0 on success, -EIO if hostname resolution fails, -ENOENT if it
 *         yields no address, or the error from tls_init().
 */
int app_mqtt_init(struct mqtt_client *client)
{
	int rc;
	char broker_ip[NET_IPV4_ADDR_LEN];
	const char *broker_ip_str;
	struct sockaddr_in *broker4;
	struct addrinfo *result;
	const struct addrinfo hints = {
		.ai_family = AF_INET,
		.ai_socktype = SOCK_STREAM
	};

	/* Resolve IP address of MQTT broker */
	rc = getaddrinfo(CONFIG_NET_SAMPLE_MQTT_BROKER_HOSTNAME,
			 CONFIG_NET_SAMPLE_MQTT_BROKER_PORT, &hints, &result);
	if (rc != 0) {
		LOG_ERR("Failed to resolve broker hostname [%s]", gai_strerror(rc));
		return -EIO;
	}
	if (result == NULL) {
		LOG_ERR("Broker address not found");
		return -ENOENT;
	}

	broker4 = (struct sockaddr_in *)&broker;
	broker4->sin_addr.s_addr = ((struct sockaddr_in *)result->ai_addr)->sin_addr.s_addr;
	broker4->sin_family = AF_INET;
	broker4->sin_port = ((struct sockaddr_in *)result->ai_addr)->sin_port;
	freeaddrinfo(result);

	/* Log resolved IP address; inet_ntop() returns NULL (leaving the
	 * buffer unset) on failure, so never print the raw buffer blindly
	 */
	broker_ip_str = inet_ntop(AF_INET, &broker4->sin_addr.s_addr,
				  broker_ip, sizeof(broker_ip));
	LOG_INF("Connecting to MQTT broker @ %s",
		broker_ip_str != NULL ? broker_ip_str : "<unknown>");

	/* MQTT client configuration */
	init_mqtt_client_id();
	mqtt_client_init(client);
	client->broker = &broker;
	client->evt_cb = mqtt_event_handler;
	client->client_id.utf8 = client_id;
	client->client_id.size = strlen((const char *)client_id);
	client->password = NULL;
	client->user_name = NULL;
	client->protocol_version = MQTT_VERSION_3_1_1;

	/* MQTT buffers configuration */
	client->rx_buf = rx_buffer;
	client->rx_buf_size = sizeof(rx_buffer);
	client->tx_buf = tx_buffer;
	client->tx_buf_size = sizeof(tx_buffer);

	/* MQTT transport configuration */
#if defined(CONFIG_MQTT_LIB_TLS)
	struct mqtt_sec_config *tls_config;

	client->transport.type = MQTT_TRANSPORT_SECURE;

	rc = tls_init();
	if (rc != 0) {
		LOG_ERR("TLS init error");
		return rc;
	}

	tls_config = &client->transport.tls.config;
	tls_config->peer_verify = TLS_PEER_VERIFY_REQUIRED;
	tls_config->cipher_list = NULL;
	tls_config->sec_tag_list = m_sec_tags;
	tls_config->sec_tag_count = ARRAY_SIZE(m_sec_tags);
#if defined(CONFIG_MBEDTLS_SERVER_NAME_INDICATION)
	tls_config->hostname = TLS_SNI_HOSTNAME;
#else
	tls_config->hostname = NULL;
#endif /* CONFIG_MBEDTLS_SERVER_NAME_INDICATION */
#endif /* CONFIG_MQTT_LIB_TLS */

	return 0;
}
511