1 /*
2  * Copyright (c) 2019 Intel Corporation
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/logging/log.h>
8 LOG_MODULE_REGISTER(mqtt_azure, LOG_LEVEL_DBG);
9 
10 #include <zephyr/kernel.h>
11 #include <zephyr/net/socket.h>
12 #include <zephyr/net/mqtt.h>
13 
14 #include <zephyr/sys/printk.h>
15 #include <zephyr/random/random.h>
16 #include <string.h>
17 #include <errno.h>
18 
19 #include "config.h"
20 #include "test_certs.h"
21 
/* Buffers for MQTT client. */
/* RX/TX buffers handed to the MQTT library in client_init(). */
static uint8_t rx_buffer[APP_MQTT_BUFFER_SIZE];
static uint8_t tx_buffer[APP_MQTT_BUFFER_SIZE];

/* The mqtt client struct */
/* Single client context shared by the connect, poll and publish paths. */
static struct mqtt_client client_ctx;

/* MQTT Broker details. */
/* Filled in by broker_init() (IPv4 only) and referenced via client->broker. */
static struct sockaddr_storage broker;

#if defined(CONFIG_SOCKS)
/* SOCKS5 proxy address, filled in by broker_init(). */
static struct sockaddr socks5_proxy;
#endif

/* Socket Poll */
/* One poll slot for the MQTT socket. nfds == 0 means "nothing prepared",
 * in which case wait() returns -EINVAL without polling.
 */
static struct zsock_pollfd fds[1];
static int nfds;

/* Set on a successful CONNACK; cleared on DISCONNECT or by
 * abort_mqtt_connection(). Read by the poll loop and the publish work item.
 */
static bool mqtt_connected;

/* Delayed work item that runs publish_timeout() periodically. */
static struct k_work_delayable pub_message;
#if defined(CONFIG_NET_DHCPV4)
/* Delayed work item that runs check_network_connection(). */
static struct k_work_delayable check_network_conn;

/* Network Management events */
#define L4_EVENT_MASK (NET_EVENT_L4_CONNECTED | NET_EVENT_L4_DISCONNECTED)

/* Callback registered in main() for the L4_EVENT_MASK events. */
static struct net_mgmt_event_callback l4_mgmt_cb;
#endif

#if defined(CONFIG_DNS_RESOLVER)
/* Lookup hints and result for the broker hostname; haddr is read by
 * broker_init().
 */
static struct zsock_addrinfo hints;
static struct zsock_addrinfo *haddr;
#endif

/* Given by check_network_connection() once DHCPv4 is bound; taken by the
 * DHCPv4 loop in connect_to_cloud_and_publish() before each connect.
 */
static K_SEM_DEFINE(mqtt_start, 0, 1);

/* Application TLS configuration details */
#define TLS_SNI_HOSTNAME CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME
#define APP_CA_CERT_TAG 1

/* Security tags used by the TLS transport; APP_CA_CERT_TAG is registered in
 * tls_init().
 */
static const sec_tag_t m_sec_tags[] = {
	APP_CA_CERT_TAG,
};

/* Cloud-to-device topic filter subscribed to by subscribe(). */
static uint8_t devbound_topic[] = "devices/" MQTT_CLIENTID "/messages/devicebound/#";
/* Subscription descriptors populated by subscribe(). */
static struct mqtt_topic subs_topic;
static struct mqtt_subscription_list subs_list;

/* Forward declaration: client_init() installs this handler. */
static void mqtt_event_handler(struct mqtt_client *const client,
			       const struct mqtt_evt *evt);
73 
tls_init(void)74 static int tls_init(void)
75 {
76 	int err;
77 
78 	err = tls_credential_add(APP_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE,
79 				 ca_certificate, sizeof(ca_certificate));
80 	if (err < 0) {
81 		LOG_ERR("Failed to register public certificate: %d", err);
82 		return err;
83 	}
84 
85 	return err;
86 }
87 
prepare_fds(struct mqtt_client * client)88 static void prepare_fds(struct mqtt_client *client)
89 {
90 	if (client->transport.type == MQTT_TRANSPORT_SECURE) {
91 		fds[0].fd = client->transport.tls.sock;
92 	}
93 
94 	fds[0].events = ZSOCK_POLLIN;
95 	nfds = 1;
96 }
97 
clear_fds(void)98 static void clear_fds(void)
99 {
100 	nfds = 0;
101 }
102 
wait(int timeout)103 static int wait(int timeout)
104 {
105 	int rc = -EINVAL;
106 
107 	if (nfds <= 0) {
108 		return rc;
109 	}
110 
111 	rc = zsock_poll(fds, nfds, timeout);
112 	if (rc < 0) {
113 		LOG_ERR("poll error: %d", errno);
114 		return -errno;
115 	}
116 
117 	return rc;
118 }
119 
broker_init(void)120 static void broker_init(void)
121 {
122 	struct sockaddr_in *broker4 = (struct sockaddr_in *)&broker;
123 
124 	broker4->sin_family = AF_INET;
125 	broker4->sin_port = htons(SERVER_PORT);
126 
127 #if defined(CONFIG_DNS_RESOLVER)
128 	net_ipaddr_copy(&broker4->sin_addr,
129 			&net_sin(haddr->ai_addr)->sin_addr);
130 #else
131 	zsock_inet_pton(AF_INET, SERVER_ADDR, &broker4->sin_addr);
132 #endif
133 
134 #if defined(CONFIG_SOCKS)
135 	struct sockaddr_in *proxy4 = (struct sockaddr_in *)&socks5_proxy;
136 
137 	proxy4->sin_family = AF_INET;
138 	proxy4->sin_port = htons(SOCKS5_PROXY_PORT);
139 	zsock_inet_pton(AF_INET, SOCKS5_PROXY_ADDR, &proxy4->sin_addr);
140 #endif
141 }
142 
client_init(struct mqtt_client * client)143 static void client_init(struct mqtt_client *client)
144 {
145 	static struct mqtt_utf8 password;
146 	static struct mqtt_utf8 username;
147 	struct mqtt_sec_config *tls_config;
148 
149 	mqtt_client_init(client);
150 
151 	broker_init();
152 
153 	/* MQTT client configuration */
154 	client->broker = &broker;
155 	client->evt_cb = mqtt_event_handler;
156 
157 	client->client_id.utf8 = (uint8_t *)MQTT_CLIENTID;
158 	client->client_id.size = strlen(MQTT_CLIENTID);
159 
160 	password.utf8 = (uint8_t *)CONFIG_SAMPLE_CLOUD_AZURE_PASSWORD;
161 	password.size = strlen(CONFIG_SAMPLE_CLOUD_AZURE_PASSWORD);
162 
163 	client->password = &password;
164 
165 	username.utf8 = (uint8_t *)CONFIG_SAMPLE_CLOUD_AZURE_USERNAME;
166 	username.size = strlen(CONFIG_SAMPLE_CLOUD_AZURE_USERNAME);
167 
168 	client->user_name = &username;
169 
170 	client->protocol_version = MQTT_VERSION_3_1_1;
171 
172 	/* MQTT buffers configuration */
173 	client->rx_buf = rx_buffer;
174 	client->rx_buf_size = sizeof(rx_buffer);
175 	client->tx_buf = tx_buffer;
176 	client->tx_buf_size = sizeof(tx_buffer);
177 
178 	/* MQTT transport configuration */
179 	client->transport.type = MQTT_TRANSPORT_SECURE;
180 
181 	tls_config = &client->transport.tls.config;
182 
183 	tls_config->peer_verify = TLS_PEER_VERIFY_REQUIRED;
184 	tls_config->cipher_list = NULL;
185 	tls_config->sec_tag_list = m_sec_tags;
186 	tls_config->sec_tag_count = ARRAY_SIZE(m_sec_tags);
187 	tls_config->hostname = TLS_SNI_HOSTNAME;
188 
189 #if defined(CONFIG_SOCKS)
190 	mqtt_client_set_proxy(client, &socks5_proxy,
191 			      socks5_proxy.sa_family == AF_INET ?
192 			      sizeof(struct sockaddr_in) :
193 			      sizeof(struct sockaddr_in6));
194 #endif
195 }
196 
mqtt_event_handler(struct mqtt_client * const client,const struct mqtt_evt * evt)197 static void mqtt_event_handler(struct mqtt_client *const client,
198 			       const struct mqtt_evt *evt)
199 {
200 	struct mqtt_puback_param puback;
201 	uint8_t data[33];
202 	int len;
203 	int bytes_read;
204 
205 	switch (evt->type) {
206 	case MQTT_EVT_SUBACK:
207 		LOG_INF("SUBACK packet id: %u", evt->param.suback.message_id);
208 		break;
209 
210 	case MQTT_EVT_UNSUBACK:
211 		LOG_INF("UNSUBACK packet id: %u", evt->param.suback.message_id);
212 		break;
213 
214 	case MQTT_EVT_CONNACK:
215 		if (evt->result) {
216 			LOG_ERR("MQTT connect failed %d", evt->result);
217 			break;
218 		}
219 
220 		mqtt_connected = true;
221 		LOG_DBG("MQTT client connected!");
222 		break;
223 
224 	case MQTT_EVT_DISCONNECT:
225 		LOG_DBG("MQTT client disconnected %d", evt->result);
226 
227 		mqtt_connected = false;
228 		clear_fds();
229 		break;
230 
231 	case MQTT_EVT_PUBACK:
232 		if (evt->result) {
233 			LOG_ERR("MQTT PUBACK error %d", evt->result);
234 			break;
235 		}
236 
237 		LOG_DBG("PUBACK packet id: %u\n", evt->param.puback.message_id);
238 		break;
239 
240 	case MQTT_EVT_PUBLISH:
241 		len = evt->param.publish.message.payload.len;
242 
243 		LOG_INF("MQTT publish received %d, %d bytes", evt->result, len);
244 		LOG_INF(" id: %d, qos: %d", evt->param.publish.message_id,
245 			evt->param.publish.message.topic.qos);
246 
247 		while (len) {
248 			bytes_read = mqtt_read_publish_payload(&client_ctx,
249 					data,
250 					len >= sizeof(data) - 1 ?
251 					sizeof(data) - 1 : len);
252 			if (bytes_read < 0 && bytes_read != -EAGAIN) {
253 				LOG_ERR("failure to read payload");
254 				break;
255 			}
256 
257 			data[bytes_read] = '\0';
258 			LOG_INF("   payload: %s", data);
259 			len -= bytes_read;
260 		}
261 
262 		puback.message_id = evt->param.publish.message_id;
263 		mqtt_publish_qos1_ack(&client_ctx, &puback);
264 		break;
265 
266 	default:
267 		LOG_DBG("Unhandled MQTT event %d", evt->type);
268 		break;
269 	}
270 }
271 
subscribe(struct mqtt_client * client)272 static void subscribe(struct mqtt_client *client)
273 {
274 	int err;
275 
276 	/* subscribe */
277 	subs_topic.topic.utf8 = devbound_topic;
278 	subs_topic.topic.size = strlen(devbound_topic);
279 	subs_list.list = &subs_topic;
280 	subs_list.list_count = 1U;
281 	subs_list.message_id = 1U;
282 
283 	err = mqtt_subscribe(client, &subs_list);
284 	if (err) {
285 		LOG_ERR("Failed on topic %s", devbound_topic);
286 	}
287 }
288 
publish(struct mqtt_client * client,enum mqtt_qos qos)289 static int publish(struct mqtt_client *client, enum mqtt_qos qos)
290 {
291 	char payload[] = "{id=123}";
292 	char evt_topic[] = "devices/" MQTT_CLIENTID "/messages/events/";
293 	uint8_t len = strlen(evt_topic);
294 	struct mqtt_publish_param param;
295 
296 	param.message.topic.qos = qos;
297 	param.message.topic.topic.utf8 = (uint8_t *)evt_topic;
298 	param.message.topic.topic.size = len;
299 	param.message.payload.data = payload;
300 	param.message.payload.len = strlen(payload);
301 	param.message_id = sys_rand32_get();
302 	param.dup_flag = 0U;
303 	param.retain_flag = 0U;
304 
305 	return mqtt_publish(client, &param);
306 }
307 
poll_mqtt(void)308 static void poll_mqtt(void)
309 {
310 	int rc;
311 
312 	while (mqtt_connected) {
313 		rc = wait(SYS_FOREVER_MS);
314 		if (rc > 0) {
315 			mqtt_input(&client_ctx);
316 		}
317 	}
318 }
319 
/* Random publish interval of 10 to 14 seconds (inclusive).
 * If you prefer an interval longer than CONFIG_MQTT_KEEPALIVE, keep the
 * application connection alive by calling mqtt_live() at regular
 * intervals.
 */
static uint8_t timeout_for_publish(void)
{
	const uint32_t jitter = sys_rand32_get() % 5U;

	return (uint8_t)(10U + jitter);
}
329 
publish_timeout(struct k_work * work)330 static void publish_timeout(struct k_work *work)
331 {
332 	int rc;
333 
334 	if (!mqtt_connected) {
335 		return;
336 	}
337 
338 	rc = publish(&client_ctx, MQTT_QOS_1_AT_LEAST_ONCE);
339 	if (rc) {
340 		LOG_ERR("mqtt_publish ERROR");
341 		goto end;
342 	}
343 
344 	LOG_DBG("mqtt_publish OK");
345 end:
346 	k_work_reschedule(&pub_message, K_SECONDS(timeout_for_publish()));
347 }
348 
try_to_connect(struct mqtt_client * client)349 static int try_to_connect(struct mqtt_client *client)
350 {
351 	uint8_t retries = 3U;
352 	int rc;
353 
354 	LOG_DBG("attempting to connect...");
355 
356 	while (retries--) {
357 		client_init(client);
358 
359 		rc = mqtt_connect(client);
360 		if (rc) {
361 			LOG_ERR("mqtt_connect failed %d", rc);
362 			continue;
363 		}
364 
365 		prepare_fds(client);
366 
367 		rc = wait(APP_SLEEP_MSECS);
368 		if (rc < 0) {
369 			mqtt_abort(client);
370 			return rc;
371 		}
372 
373 		mqtt_input(client);
374 
375 		if (mqtt_connected) {
376 			subscribe(client);
377 			k_work_reschedule(&pub_message,
378 					  K_SECONDS(timeout_for_publish()));
379 			return 0;
380 		}
381 
382 		mqtt_abort(client);
383 
384 		wait(10 * MSEC_PER_SEC);
385 	}
386 
387 	return -EINVAL;
388 }
389 
390 #if defined(CONFIG_DNS_RESOLVER)
/*
 * Resolve CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME to an IPv4 stream address,
 * trying up to three times.
 *
 * On success the result is stored in haddr (read by broker_init()) and 0
 * is returned. Otherwise the last zsock_getaddrinfo() error is returned.
 *
 * Any result left over from a previous call is released first. The DHCPv4
 * loop in connect_to_cloud_and_publish() calls this on every reconnect,
 * which would otherwise leak one addrinfo list each time.
 */
static int get_mqtt_broker_addrinfo(void)
{
	int retries = 3;
	int rc = -EINVAL;

	if (haddr != NULL) {
		zsock_freeaddrinfo(haddr);
		haddr = NULL;
	}

	while (retries--) {
		hints.ai_family = AF_INET;
		hints.ai_socktype = SOCK_STREAM;
		hints.ai_protocol = 0;

		/* Only the address is used; broker_init() sets the port itself. */
		rc = zsock_getaddrinfo(CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME, "8883",
				       &hints, &haddr);
		if (rc == 0) {
			LOG_INF("DNS resolved for %s:%d",
				CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME,
				CONFIG_SAMPLE_CLOUD_AZURE_SERVER_PORT);

			return 0;
		}

		LOG_ERR("DNS not resolved for %s:%d, retrying",
			CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME,
			CONFIG_SAMPLE_CLOUD_AZURE_SERVER_PORT);
	}

	return rc;
}
418 #endif
419 
connect_to_cloud_and_publish(void)420 static void connect_to_cloud_and_publish(void)
421 {
422 	int rc = -EINVAL;
423 
424 #if defined(CONFIG_NET_DHCPV4)
425 	while (true) {
426 		k_sem_take(&mqtt_start, K_FOREVER);
427 #endif
428 #if defined(CONFIG_DNS_RESOLVER)
429 		rc = get_mqtt_broker_addrinfo();
430 		if (rc) {
431 			return;
432 		}
433 #endif
434 		rc = try_to_connect(&client_ctx);
435 		if (rc) {
436 			return;
437 		}
438 
439 		poll_mqtt();
440 #if defined(CONFIG_NET_DHCPV4)
441 	}
442 #endif
443 }
444 
/* After the interface goes down and comes back up, DHCPv4 tries to renew
 * its address. A successful renewal does not generate any event, so the
 * DHCPv4 state has to be monitored by polling, as done here.
 * If the number of DHCPv4 attempts exceeds the maximum, the interface
 * address is deleted and a new request is started; in that case we can
 * rely on the IPV4_ADDR_ADD event.
 */
452 #if defined(CONFIG_NET_DHCPV4)
check_network_connection(struct k_work * work)453 static void check_network_connection(struct k_work *work)
454 {
455 	struct net_if *iface;
456 
457 	if (mqtt_connected) {
458 		return;
459 	}
460 
461 	iface = net_if_get_default();
462 	if (!iface) {
463 		goto end;
464 	}
465 
466 	if (iface->config.dhcpv4.state == NET_DHCPV4_BOUND) {
467 		k_sem_give(&mqtt_start);
468 		return;
469 	}
470 
471 	LOG_INF("waiting for DHCP to acquire addr");
472 
473 end:
474 	k_work_reschedule(&check_network_conn, K_SECONDS(3));
475 }
476 #endif
477 
478 #if defined(CONFIG_NET_DHCPV4)
abort_mqtt_connection(void)479 static void abort_mqtt_connection(void)
480 {
481 	if (mqtt_connected) {
482 		mqtt_connected = false;
483 		mqtt_abort(&client_ctx);
484 		k_work_cancel_delayable(&pub_message);
485 	}
486 }
487 
l4_event_handler(struct net_mgmt_event_callback * cb,uint32_t mgmt_event,struct net_if * iface)488 static void l4_event_handler(struct net_mgmt_event_callback *cb,
489 			     uint32_t mgmt_event, struct net_if *iface)
490 {
491 	if ((mgmt_event & L4_EVENT_MASK) != mgmt_event) {
492 		return;
493 	}
494 
495 	if (mgmt_event == NET_EVENT_L4_CONNECTED) {
496 		/* Wait for DHCP to be back in BOUND state */
497 		k_work_reschedule(&check_network_conn, K_SECONDS(3));
498 
499 		return;
500 	}
501 
502 	if (mgmt_event == NET_EVENT_L4_DISCONNECTED) {
503 		abort_mqtt_connection();
504 		k_work_cancel_delayable(&check_network_conn);
505 
506 		return;
507 	}
508 }
509 #endif
510 
main(void)511 int main(void)
512 {
513 	int rc;
514 
515 	LOG_DBG("Waiting for network to setup...");
516 
517 	rc = tls_init();
518 	if (rc) {
519 		return 0;
520 	}
521 
522 	k_work_init_delayable(&pub_message, publish_timeout);
523 
524 #if defined(CONFIG_NET_DHCPV4)
525 	k_work_init_delayable(&check_network_conn, check_network_connection);
526 
527 	net_mgmt_init_event_callback(&l4_mgmt_cb, l4_event_handler,
528 				     L4_EVENT_MASK);
529 	net_mgmt_add_event_callback(&l4_mgmt_cb);
530 #endif
531 
532 	connect_to_cloud_and_publish();
533 	return 0;
534 }
535