1 /*
2  * Copyright (c) 2019 Intel Corporation
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/logging/log.h>
8 LOG_MODULE_REGISTER(mqtt_azure, LOG_LEVEL_DBG);
9 
10 #include <zephyr/posix/netinet/in.h>
11 #include <zephyr/posix/sys/socket.h>
12 #include <zephyr/posix/arpa/inet.h>
13 #include <zephyr/posix/unistd.h>
14 #include <zephyr/posix/netdb.h>
15 #include <zephyr/posix/poll.h>
16 
17 #include <zephyr/kernel.h>
18 #include <zephyr/net/net_if.h>
19 #include <zephyr/net/socket.h>
20 #include <zephyr/net/mqtt.h>
21 
22 #include <zephyr/sys/printk.h>
23 #include <zephyr/random/random.h>
24 #include <string.h>
25 #include <errno.h>
26 
27 #include "config.h"
28 #include "test_certs.h"
29 
/* Buffers for MQTT client.
 * Handed to the MQTT library in client_init() as its RX/TX work buffers.
 */
static uint8_t rx_buffer[APP_MQTT_BUFFER_SIZE];
static uint8_t tx_buffer[APP_MQTT_BUFFER_SIZE];

/* The mqtt client struct: the single MQTT session used by this sample. */
static struct mqtt_client client_ctx;

/* MQTT Broker details.
 * Filled in by broker_init() (treated as IPv4 / struct sockaddr_in there).
 */
static struct sockaddr_storage broker;

#if defined(CONFIG_SOCKS)
/* SOCKS5 proxy address, filled in by broker_init() and registered with the
 * client via mqtt_client_set_proxy() in client_init().
 */
static struct sockaddr socks5_proxy;
#endif

/* Socket Poll
 * fds[0] holds the MQTT socket once prepare_fds() has run; nfds is the
 * number of valid entries (0 means "nothing to poll", see clear_fds()).
 */
static struct pollfd fds[1];
static int nfds;

/* Set true on a successful CONNACK and false on DISCONNECT (and, with
 * CONFIG_NET_DHCPV4, by abort_mqtt_connection()).
 * NOTE(review): read and written from several contexts (main loop, work
 * items, net_mgmt callback) without synchronization.
 */
static bool mqtt_connected;

/* Delayable work that runs publish_timeout() to send periodic messages. */
static struct k_work_delayable pub_message;
#if defined(CONFIG_NET_DHCPV4)
/* Delayable work that runs check_network_connection() to poll DHCPv4 state. */
static struct k_work_delayable check_network_conn;

/* Network Management events */
#define L4_EVENT_MASK (NET_EVENT_L4_CONNECTED | NET_EVENT_L4_DISCONNECTED)

/* Callback registered in main() for the L4 connect/disconnect events. */
static struct net_mgmt_event_callback l4_mgmt_cb;
#endif

#if defined(CONFIG_DNS_RESOLVER)
/* getaddrinfo() hints and result for the broker hostname; haddr is read by
 * broker_init().
 */
static struct addrinfo hints;
static struct addrinfo *haddr;
#endif

/* Binary-style semaphore (initial count 0, limit 1): given by
 * check_network_connection() once DHCPv4 is bound, taken by
 * connect_to_cloud_and_publish() before each connection attempt.
 */
static K_SEM_DEFINE(mqtt_start, 0, 1);

/* Application TLS configuration details */
#define TLS_SNI_HOSTNAME CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME
#define APP_CA_CERT_TAG 1

/* Security tags offered to the TLS socket; only the CA certificate
 * registered by tls_init().
 */
static const sec_tag_t m_sec_tags[] = {
	APP_CA_CERT_TAG,
};

/* Cloud-to-device topic filter, plus the subscription structures filled in
 * by subscribe().
 */
static uint8_t devbound_topic[] = "devices/" MQTT_CLIENTID "/messages/devicebound/#";
static struct mqtt_topic subs_topic;
static struct mqtt_subscription_list subs_list;

/* Forward declaration: client_init() installs this as the event callback. */
static void mqtt_event_handler(struct mqtt_client *const client,
			       const struct mqtt_evt *evt);
81 
tls_init(void)82 static int tls_init(void)
83 {
84 	int err;
85 
86 	err = tls_credential_add(APP_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE,
87 				 ca_certificate, sizeof(ca_certificate));
88 	if (err < 0) {
89 		LOG_ERR("Failed to register public certificate: %d", err);
90 		return err;
91 	}
92 
93 	return err;
94 }
95 
prepare_fds(struct mqtt_client * client)96 static void prepare_fds(struct mqtt_client *client)
97 {
98 	if (client->transport.type == MQTT_TRANSPORT_SECURE) {
99 		fds[0].fd = client->transport.tls.sock;
100 	}
101 
102 	fds[0].events = POLLIN;
103 	nfds = 1;
104 }
105 
clear_fds(void)106 static void clear_fds(void)
107 {
108 	nfds = 0;
109 }
110 
wait(int timeout)111 static int wait(int timeout)
112 {
113 	int rc = -EINVAL;
114 
115 	if (nfds <= 0) {
116 		return rc;
117 	}
118 
119 	rc = poll(fds, nfds, timeout);
120 	if (rc < 0) {
121 		LOG_ERR("poll error: %d", errno);
122 		return -errno;
123 	}
124 
125 	return rc;
126 }
127 
broker_init(void)128 static void broker_init(void)
129 {
130 	struct sockaddr_in *broker4 = (struct sockaddr_in *)&broker;
131 
132 	broker4->sin_family = AF_INET;
133 	broker4->sin_port = htons(SERVER_PORT);
134 
135 #if defined(CONFIG_DNS_RESOLVER)
136 	net_ipaddr_copy(&broker4->sin_addr,
137 			&net_sin(haddr->ai_addr)->sin_addr);
138 #else
139 	inet_pton(AF_INET, SERVER_ADDR, &broker4->sin_addr);
140 #endif
141 
142 #if defined(CONFIG_SOCKS)
143 	struct sockaddr_in *proxy4 = (struct sockaddr_in *)&socks5_proxy;
144 
145 	proxy4->sin_family = AF_INET;
146 	proxy4->sin_port = htons(SOCKS5_PROXY_PORT);
147 	inet_pton(AF_INET, SOCKS5_PROXY_ADDR, &proxy4->sin_addr);
148 #endif
149 }
150 
client_init(struct mqtt_client * client)151 static void client_init(struct mqtt_client *client)
152 {
153 	static struct mqtt_utf8 password;
154 	static struct mqtt_utf8 username;
155 	struct mqtt_sec_config *tls_config;
156 
157 	mqtt_client_init(client);
158 
159 	broker_init();
160 
161 	/* MQTT client configuration */
162 	client->broker = &broker;
163 	client->evt_cb = mqtt_event_handler;
164 
165 	client->client_id.utf8 = (uint8_t *)MQTT_CLIENTID;
166 	client->client_id.size = strlen(MQTT_CLIENTID);
167 
168 	password.utf8 = (uint8_t *)CONFIG_SAMPLE_CLOUD_AZURE_PASSWORD;
169 	password.size = strlen(CONFIG_SAMPLE_CLOUD_AZURE_PASSWORD);
170 
171 	client->password = &password;
172 
173 	username.utf8 = (uint8_t *)CONFIG_SAMPLE_CLOUD_AZURE_USERNAME;
174 	username.size = strlen(CONFIG_SAMPLE_CLOUD_AZURE_USERNAME);
175 
176 	client->user_name = &username;
177 
178 	client->protocol_version = MQTT_VERSION_3_1_1;
179 
180 	/* MQTT buffers configuration */
181 	client->rx_buf = rx_buffer;
182 	client->rx_buf_size = sizeof(rx_buffer);
183 	client->tx_buf = tx_buffer;
184 	client->tx_buf_size = sizeof(tx_buffer);
185 
186 	/* MQTT transport configuration */
187 	client->transport.type = MQTT_TRANSPORT_SECURE;
188 
189 	tls_config = &client->transport.tls.config;
190 
191 	tls_config->peer_verify = TLS_PEER_VERIFY_REQUIRED;
192 	tls_config->cipher_list = NULL;
193 	tls_config->sec_tag_list = m_sec_tags;
194 	tls_config->sec_tag_count = ARRAY_SIZE(m_sec_tags);
195 	tls_config->hostname = TLS_SNI_HOSTNAME;
196 
197 #if defined(CONFIG_SOCKS)
198 	mqtt_client_set_proxy(client, &socks5_proxy,
199 			      socks5_proxy.sa_family == AF_INET ?
200 			      sizeof(struct sockaddr_in) :
201 			      sizeof(struct sockaddr_in6));
202 #endif
203 }
204 
mqtt_event_handler(struct mqtt_client * const client,const struct mqtt_evt * evt)205 static void mqtt_event_handler(struct mqtt_client *const client,
206 			       const struct mqtt_evt *evt)
207 {
208 	struct mqtt_puback_param puback;
209 	uint8_t data[33];
210 	int len;
211 	int bytes_read;
212 
213 	switch (evt->type) {
214 	case MQTT_EVT_SUBACK:
215 		LOG_INF("SUBACK packet id: %u", evt->param.suback.message_id);
216 		break;
217 
218 	case MQTT_EVT_UNSUBACK:
219 		LOG_INF("UNSUBACK packet id: %u", evt->param.suback.message_id);
220 		break;
221 
222 	case MQTT_EVT_CONNACK:
223 		if (evt->result) {
224 			LOG_ERR("MQTT connect failed %d", evt->result);
225 			break;
226 		}
227 
228 		mqtt_connected = true;
229 		LOG_DBG("MQTT client connected!");
230 		break;
231 
232 	case MQTT_EVT_DISCONNECT:
233 		LOG_DBG("MQTT client disconnected %d", evt->result);
234 
235 		mqtt_connected = false;
236 		clear_fds();
237 		break;
238 
239 	case MQTT_EVT_PUBACK:
240 		if (evt->result) {
241 			LOG_ERR("MQTT PUBACK error %d", evt->result);
242 			break;
243 		}
244 
245 		LOG_DBG("PUBACK packet id: %u\n", evt->param.puback.message_id);
246 		break;
247 
248 	case MQTT_EVT_PUBLISH:
249 		len = evt->param.publish.message.payload.len;
250 
251 		LOG_INF("MQTT publish received %d, %d bytes", evt->result, len);
252 		LOG_INF(" id: %d, qos: %d", evt->param.publish.message_id,
253 			evt->param.publish.message.topic.qos);
254 
255 		while (len) {
256 			bytes_read = mqtt_read_publish_payload(&client_ctx,
257 					data,
258 					len >= sizeof(data) - 1 ?
259 					sizeof(data) - 1 : len);
260 			if (bytes_read < 0 && bytes_read != -EAGAIN) {
261 				LOG_ERR("failure to read payload");
262 				break;
263 			}
264 
265 			data[bytes_read] = '\0';
266 			LOG_INF("   payload: %s", data);
267 			len -= bytes_read;
268 		}
269 
270 		puback.message_id = evt->param.publish.message_id;
271 		mqtt_publish_qos1_ack(&client_ctx, &puback);
272 		break;
273 
274 	default:
275 		LOG_DBG("Unhandled MQTT event %d", evt->type);
276 		break;
277 	}
278 }
279 
subscribe(struct mqtt_client * client)280 static void subscribe(struct mqtt_client *client)
281 {
282 	int err;
283 
284 	/* subscribe */
285 	subs_topic.topic.utf8 = devbound_topic;
286 	subs_topic.topic.size = strlen(devbound_topic);
287 	subs_list.list = &subs_topic;
288 	subs_list.list_count = 1U;
289 	subs_list.message_id = 1U;
290 
291 	err = mqtt_subscribe(client, &subs_list);
292 	if (err) {
293 		LOG_ERR("Failed on topic %s", devbound_topic);
294 	}
295 }
296 
publish(struct mqtt_client * client,enum mqtt_qos qos)297 static int publish(struct mqtt_client *client, enum mqtt_qos qos)
298 {
299 	char payload[] = "{id=123}";
300 	char evt_topic[] = "devices/" MQTT_CLIENTID "/messages/events/";
301 	uint8_t len = strlen(evt_topic);
302 	struct mqtt_publish_param param;
303 
304 	param.message.topic.qos = qos;
305 	param.message.topic.topic.utf8 = (uint8_t *)evt_topic;
306 	param.message.topic.topic.size = len;
307 	param.message.payload.data = payload;
308 	param.message.payload.len = strlen(payload);
309 	param.message_id = sys_rand16_get();
310 	param.dup_flag = 0U;
311 	param.retain_flag = 0U;
312 
313 	return mqtt_publish(client, &param);
314 }
315 
poll_mqtt(void)316 static void poll_mqtt(void)
317 {
318 	int rc;
319 
320 	while (mqtt_connected) {
321 		rc = wait(SYS_FOREVER_MS);
322 		if (rc > 0) {
323 			mqtt_input(&client_ctx);
324 		}
325 	}
326 }
327 
/* Delay, in seconds, before the next publish: pseudo-random in [10, 14].
 *
 * This sample never calls mqtt_live(); it relies on these publishes to keep
 * the session active. If you raise this above CONFIG_MQTT_KEEPALIVE, call
 * mqtt_live() at regular intervals to keep the connection alive.
 */
static uint8_t timeout_for_publish(void)
{
	const uint8_t jitter = sys_rand8_get() % 5;

	return 10 + jitter;
}
337 
publish_timeout(struct k_work * work)338 static void publish_timeout(struct k_work *work)
339 {
340 	int rc;
341 
342 	if (!mqtt_connected) {
343 		return;
344 	}
345 
346 	rc = publish(&client_ctx, MQTT_QOS_1_AT_LEAST_ONCE);
347 	if (rc) {
348 		LOG_ERR("mqtt_publish ERROR");
349 		goto end;
350 	}
351 
352 	LOG_DBG("mqtt_publish OK");
353 end:
354 	k_work_reschedule(&pub_message, K_SECONDS(timeout_for_publish()));
355 }
356 
try_to_connect(struct mqtt_client * client)357 static int try_to_connect(struct mqtt_client *client)
358 {
359 	uint8_t retries = 3U;
360 	int rc;
361 
362 	LOG_DBG("attempting to connect...");
363 
364 	while (retries--) {
365 		client_init(client);
366 
367 		rc = mqtt_connect(client);
368 		if (rc) {
369 			LOG_ERR("mqtt_connect failed %d", rc);
370 			continue;
371 		}
372 
373 		prepare_fds(client);
374 
375 		rc = wait(APP_SLEEP_MSECS);
376 		if (rc < 0) {
377 			mqtt_abort(client);
378 			return rc;
379 		}
380 
381 		mqtt_input(client);
382 
383 		if (mqtt_connected) {
384 			subscribe(client);
385 			k_work_reschedule(&pub_message,
386 					  K_SECONDS(timeout_for_publish()));
387 			return 0;
388 		}
389 
390 		mqtt_abort(client);
391 
392 		wait(10 * MSEC_PER_SEC);
393 	}
394 
395 	return -EINVAL;
396 }
397 
398 #if defined(CONFIG_DNS_RESOLVER)
get_mqtt_broker_addrinfo(void)399 static int get_mqtt_broker_addrinfo(void)
400 {
401 	int retries = 3;
402 	int rc = -EINVAL;
403 
404 	while (retries--) {
405 		hints.ai_family = AF_INET;
406 		hints.ai_socktype = SOCK_STREAM;
407 		hints.ai_protocol = 0;
408 
409 		rc = getaddrinfo(CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME, "8883",
410 				 &hints, &haddr);
411 		if (rc == 0) {
412 			LOG_INF("DNS resolved for %s:%d",
413 			CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME,
414 			CONFIG_SAMPLE_CLOUD_AZURE_SERVER_PORT);
415 
416 			return 0;
417 		}
418 
419 		LOG_ERR("DNS not resolved for %s:%d, retrying",
420 			CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME,
421 			CONFIG_SAMPLE_CLOUD_AZURE_SERVER_PORT);
422 	}
423 
424 	return rc;
425 }
426 #endif
427 
connect_to_cloud_and_publish(void)428 static void connect_to_cloud_and_publish(void)
429 {
430 	int rc = -EINVAL;
431 
432 #if defined(CONFIG_NET_DHCPV4)
433 	while (true) {
434 		k_sem_take(&mqtt_start, K_FOREVER);
435 #endif
436 #if defined(CONFIG_DNS_RESOLVER)
437 		rc = get_mqtt_broker_addrinfo();
438 		if (rc) {
439 			return;
440 		}
441 #endif
442 		rc = try_to_connect(&client_ctx);
443 		if (rc) {
444 			return;
445 		}
446 
447 		poll_mqtt();
448 #if defined(CONFIG_NET_DHCPV4)
449 	}
450 #endif
451 }
452 
/* After the interface goes down and comes back up, DHCPv4 tries to renew
 * its previous address. A successful renewal does not generate any event,
 * so the DHCPv4 state has to be polled instead, as done below.
 * If the number of DHCPv4 attempts exceeds the maximum, the iface address
 * is deleted and a new request is started; in that case we can rely on
 * the IPV4_ADDR_ADD event.
 */
460 #if defined(CONFIG_NET_DHCPV4)
check_network_connection(struct k_work * work)461 static void check_network_connection(struct k_work *work)
462 {
463 	struct net_if *iface;
464 
465 	if (mqtt_connected) {
466 		return;
467 	}
468 
469 	iface = net_if_get_default();
470 	if (!iface) {
471 		goto end;
472 	}
473 
474 	if (iface->config.dhcpv4.state == NET_DHCPV4_BOUND) {
475 		k_sem_give(&mqtt_start);
476 		return;
477 	}
478 
479 	LOG_INF("waiting for DHCP to acquire addr");
480 
481 end:
482 	k_work_reschedule(&check_network_conn, K_SECONDS(3));
483 }
484 #endif
485 
486 #if defined(CONFIG_NET_DHCPV4)
abort_mqtt_connection(void)487 static void abort_mqtt_connection(void)
488 {
489 	if (mqtt_connected) {
490 		mqtt_connected = false;
491 		mqtt_abort(&client_ctx);
492 		k_work_cancel_delayable(&pub_message);
493 	}
494 }
495 
l4_event_handler(struct net_mgmt_event_callback * cb,uint64_t mgmt_event,struct net_if * iface)496 static void l4_event_handler(struct net_mgmt_event_callback *cb,
497 			     uint64_t mgmt_event, struct net_if *iface)
498 {
499 	if ((mgmt_event & L4_EVENT_MASK) != mgmt_event) {
500 		return;
501 	}
502 
503 	if (mgmt_event == NET_EVENT_L4_CONNECTED) {
504 		/* Wait for DHCP to be back in BOUND state */
505 		k_work_reschedule(&check_network_conn, K_SECONDS(3));
506 
507 		return;
508 	}
509 
510 	if (mgmt_event == NET_EVENT_L4_DISCONNECTED) {
511 		abort_mqtt_connection();
512 		k_work_cancel_delayable(&check_network_conn);
513 
514 		return;
515 	}
516 }
517 #endif
518 
main(void)519 int main(void)
520 {
521 	int rc;
522 
523 	LOG_DBG("Waiting for network to setup...");
524 
525 	rc = tls_init();
526 	if (rc) {
527 		return 0;
528 	}
529 
530 	k_work_init_delayable(&pub_message, publish_timeout);
531 
532 #if defined(CONFIG_NET_DHCPV4)
533 	k_work_init_delayable(&check_network_conn, check_network_connection);
534 
535 	net_mgmt_init_event_callback(&l4_mgmt_cb, l4_event_handler,
536 				     L4_EVENT_MASK);
537 	net_mgmt_add_event_callback(&l4_mgmt_cb);
538 #endif
539 
540 	connect_to_cloud_and_publish();
541 	return 0;
542 }
543