1 /*
2  * Copyright (c) 2019 Intel Corporation
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/logging/log.h>
8 LOG_MODULE_REGISTER(mqtt_azure, LOG_LEVEL_DBG);
9 
10 #include <zephyr/kernel.h>
11 #include <zephyr/net/net_if.h>
12 #include <zephyr/net/socket.h>
13 #include <zephyr/net/mqtt.h>
14 
15 #include <zephyr/sys/printk.h>
16 #include <zephyr/random/random.h>
17 #include <string.h>
18 #include <errno.h>
19 
20 #include "config.h"
21 #include "test_certs.h"
22 
/* Static RX/TX buffers handed to the MQTT library by client_init(). */
static uint8_t rx_buffer[APP_MQTT_BUFFER_SIZE];
static uint8_t tx_buffer[APP_MQTT_BUFFER_SIZE];

/* The single MQTT client instance used throughout this sample. */
static struct mqtt_client client_ctx;

/* Broker address; filled in (as IPv4) by broker_init(). */
static struct sockaddr_storage broker;

#if defined(CONFIG_SOCKS)
/* SOCKS5 proxy address; filled in (as IPv4) by broker_init(). */
static struct sockaddr socks5_proxy;
#endif

/* Poll set for the client socket: fds[0] is set up by prepare_fds();
 * nfds is 1 while that entry is valid and 0 after clear_fds().
 */
static struct pollfd fds[1];
static int nfds;

/* Set on a successful CONNACK, cleared on disconnect or when the network
 * goes down. Gates the poll loop and the periodic publish work.
 */
static bool mqtt_connected;

/* Delayable work item running publish_timeout() (set up in main()). */
static struct k_work_delayable pub_message;
#if defined(CONFIG_NET_DHCPV4)
/* Delayable work item running check_network_connection(). */
static struct k_work_delayable check_network_conn;

/* Network Management events: L4 connectivity up/down. */
#define L4_EVENT_MASK (NET_EVENT_L4_CONNECTED | NET_EVENT_L4_DISCONNECTED)

static struct net_mgmt_event_callback l4_mgmt_cb;
#endif

#if defined(CONFIG_DNS_RESOLVER)
/* getaddrinfo() hints and result; broker_init() copies the first address. */
static struct addrinfo hints;
static struct addrinfo *haddr;
#endif

/* Initially 0, max count 1: given by check_network_connection() when DHCP
 * is bound, taken before each connection attempt in DHCP builds.
 */
static K_SEM_DEFINE(mqtt_start, 0, 1);

/* Application TLS configuration details */
#define TLS_SNI_HOSTNAME CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME
/* Credential tag under which tls_init() registers the CA certificate. */
#define APP_CA_CERT_TAG 1

static const sec_tag_t m_sec_tags[] = {
	APP_CA_CERT_TAG,
};

/* Cloud-to-device topic that subscribe() subscribes to. */
static uint8_t devbound_topic[] = "devices/" MQTT_CLIENTID "/messages/devicebound/#";
static struct mqtt_topic subs_topic;
static struct mqtt_subscription_list subs_list;

/* Forward declaration: client_init() registers this as the event callback. */
static void mqtt_event_handler(struct mqtt_client *const client,
			       const struct mqtt_evt *evt);
74 
tls_init(void)75 static int tls_init(void)
76 {
77 	int err;
78 
79 	err = tls_credential_add(APP_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE,
80 				 ca_certificate, sizeof(ca_certificate));
81 	if (err < 0) {
82 		LOG_ERR("Failed to register public certificate: %d", err);
83 		return err;
84 	}
85 
86 	return err;
87 }
88 
prepare_fds(struct mqtt_client * client)89 static void prepare_fds(struct mqtt_client *client)
90 {
91 	if (client->transport.type == MQTT_TRANSPORT_SECURE) {
92 		fds[0].fd = client->transport.tls.sock;
93 	}
94 
95 	fds[0].events = POLLIN;
96 	nfds = 1;
97 }
98 
clear_fds(void)99 static void clear_fds(void)
100 {
101 	nfds = 0;
102 }
103 
wait(int timeout)104 static int wait(int timeout)
105 {
106 	int rc = -EINVAL;
107 
108 	if (nfds <= 0) {
109 		return rc;
110 	}
111 
112 	rc = poll(fds, nfds, timeout);
113 	if (rc < 0) {
114 		LOG_ERR("poll error: %d", errno);
115 		return -errno;
116 	}
117 
118 	return rc;
119 }
120 
broker_init(void)121 static void broker_init(void)
122 {
123 	struct sockaddr_in *broker4 = (struct sockaddr_in *)&broker;
124 
125 	broker4->sin_family = AF_INET;
126 	broker4->sin_port = htons(SERVER_PORT);
127 
128 #if defined(CONFIG_DNS_RESOLVER)
129 	net_ipaddr_copy(&broker4->sin_addr,
130 			&net_sin(haddr->ai_addr)->sin_addr);
131 #else
132 	inet_pton(AF_INET, SERVER_ADDR, &broker4->sin_addr);
133 #endif
134 
135 #if defined(CONFIG_SOCKS)
136 	struct sockaddr_in *proxy4 = (struct sockaddr_in *)&socks5_proxy;
137 
138 	proxy4->sin_family = AF_INET;
139 	proxy4->sin_port = htons(SOCKS5_PROXY_PORT);
140 	inet_pton(AF_INET, SOCKS5_PROXY_ADDR, &proxy4->sin_addr);
141 #endif
142 }
143 
client_init(struct mqtt_client * client)144 static void client_init(struct mqtt_client *client)
145 {
146 	static struct mqtt_utf8 password;
147 	static struct mqtt_utf8 username;
148 	struct mqtt_sec_config *tls_config;
149 
150 	mqtt_client_init(client);
151 
152 	broker_init();
153 
154 	/* MQTT client configuration */
155 	client->broker = &broker;
156 	client->evt_cb = mqtt_event_handler;
157 
158 	client->client_id.utf8 = (uint8_t *)MQTT_CLIENTID;
159 	client->client_id.size = strlen(MQTT_CLIENTID);
160 
161 	password.utf8 = (uint8_t *)CONFIG_SAMPLE_CLOUD_AZURE_PASSWORD;
162 	password.size = strlen(CONFIG_SAMPLE_CLOUD_AZURE_PASSWORD);
163 
164 	client->password = &password;
165 
166 	username.utf8 = (uint8_t *)CONFIG_SAMPLE_CLOUD_AZURE_USERNAME;
167 	username.size = strlen(CONFIG_SAMPLE_CLOUD_AZURE_USERNAME);
168 
169 	client->user_name = &username;
170 
171 	client->protocol_version = MQTT_VERSION_3_1_1;
172 
173 	/* MQTT buffers configuration */
174 	client->rx_buf = rx_buffer;
175 	client->rx_buf_size = sizeof(rx_buffer);
176 	client->tx_buf = tx_buffer;
177 	client->tx_buf_size = sizeof(tx_buffer);
178 
179 	/* MQTT transport configuration */
180 	client->transport.type = MQTT_TRANSPORT_SECURE;
181 
182 	tls_config = &client->transport.tls.config;
183 
184 	tls_config->peer_verify = TLS_PEER_VERIFY_REQUIRED;
185 	tls_config->cipher_list = NULL;
186 	tls_config->sec_tag_list = m_sec_tags;
187 	tls_config->sec_tag_count = ARRAY_SIZE(m_sec_tags);
188 	tls_config->hostname = TLS_SNI_HOSTNAME;
189 
190 #if defined(CONFIG_SOCKS)
191 	mqtt_client_set_proxy(client, &socks5_proxy,
192 			      socks5_proxy.sa_family == AF_INET ?
193 			      sizeof(struct sockaddr_in) :
194 			      sizeof(struct sockaddr_in6));
195 #endif
196 }
197 
mqtt_event_handler(struct mqtt_client * const client,const struct mqtt_evt * evt)198 static void mqtt_event_handler(struct mqtt_client *const client,
199 			       const struct mqtt_evt *evt)
200 {
201 	struct mqtt_puback_param puback;
202 	uint8_t data[33];
203 	int len;
204 	int bytes_read;
205 
206 	switch (evt->type) {
207 	case MQTT_EVT_SUBACK:
208 		LOG_INF("SUBACK packet id: %u", evt->param.suback.message_id);
209 		break;
210 
211 	case MQTT_EVT_UNSUBACK:
212 		LOG_INF("UNSUBACK packet id: %u", evt->param.suback.message_id);
213 		break;
214 
215 	case MQTT_EVT_CONNACK:
216 		if (evt->result) {
217 			LOG_ERR("MQTT connect failed %d", evt->result);
218 			break;
219 		}
220 
221 		mqtt_connected = true;
222 		LOG_DBG("MQTT client connected!");
223 		break;
224 
225 	case MQTT_EVT_DISCONNECT:
226 		LOG_DBG("MQTT client disconnected %d", evt->result);
227 
228 		mqtt_connected = false;
229 		clear_fds();
230 		break;
231 
232 	case MQTT_EVT_PUBACK:
233 		if (evt->result) {
234 			LOG_ERR("MQTT PUBACK error %d", evt->result);
235 			break;
236 		}
237 
238 		LOG_DBG("PUBACK packet id: %u\n", evt->param.puback.message_id);
239 		break;
240 
241 	case MQTT_EVT_PUBLISH:
242 		len = evt->param.publish.message.payload.len;
243 
244 		LOG_INF("MQTT publish received %d, %d bytes", evt->result, len);
245 		LOG_INF(" id: %d, qos: %d", evt->param.publish.message_id,
246 			evt->param.publish.message.topic.qos);
247 
248 		while (len) {
249 			bytes_read = mqtt_read_publish_payload(&client_ctx,
250 					data,
251 					len >= sizeof(data) - 1 ?
252 					sizeof(data) - 1 : len);
253 			if (bytes_read < 0 && bytes_read != -EAGAIN) {
254 				LOG_ERR("failure to read payload");
255 				break;
256 			}
257 
258 			data[bytes_read] = '\0';
259 			LOG_INF("   payload: %s", data);
260 			len -= bytes_read;
261 		}
262 
263 		puback.message_id = evt->param.publish.message_id;
264 		mqtt_publish_qos1_ack(&client_ctx, &puback);
265 		break;
266 
267 	default:
268 		LOG_DBG("Unhandled MQTT event %d", evt->type);
269 		break;
270 	}
271 }
272 
subscribe(struct mqtt_client * client)273 static void subscribe(struct mqtt_client *client)
274 {
275 	int err;
276 
277 	/* subscribe */
278 	subs_topic.topic.utf8 = devbound_topic;
279 	subs_topic.topic.size = strlen(devbound_topic);
280 	subs_list.list = &subs_topic;
281 	subs_list.list_count = 1U;
282 	subs_list.message_id = 1U;
283 
284 	err = mqtt_subscribe(client, &subs_list);
285 	if (err) {
286 		LOG_ERR("Failed on topic %s", devbound_topic);
287 	}
288 }
289 
publish(struct mqtt_client * client,enum mqtt_qos qos)290 static int publish(struct mqtt_client *client, enum mqtt_qos qos)
291 {
292 	char payload[] = "{id=123}";
293 	char evt_topic[] = "devices/" MQTT_CLIENTID "/messages/events/";
294 	uint8_t len = strlen(evt_topic);
295 	struct mqtt_publish_param param;
296 
297 	param.message.topic.qos = qos;
298 	param.message.topic.topic.utf8 = (uint8_t *)evt_topic;
299 	param.message.topic.topic.size = len;
300 	param.message.payload.data = payload;
301 	param.message.payload.len = strlen(payload);
302 	param.message_id = sys_rand16_get();
303 	param.dup_flag = 0U;
304 	param.retain_flag = 0U;
305 
306 	return mqtt_publish(client, &param);
307 }
308 
poll_mqtt(void)309 static void poll_mqtt(void)
310 {
311 	int rc;
312 
313 	while (mqtt_connected) {
314 		rc = wait(SYS_FOREVER_MS);
315 		if (rc > 0) {
316 			mqtt_input(&client_ctx);
317 		}
318 	}
319 }
320 
/*
 * Pick a random publish interval between 10 and 14 seconds (10 + 0..4).
 * If you choose an interval longer than CONFIG_MQTT_KEEPALIVE, keep the
 * connection alive by calling mqtt_live() at regular intervals.
 */
static uint8_t timeout_for_publish(void)
{
	uint8_t jitter = sys_rand8_get() % 5;

	return 10 + jitter;
}
330 
publish_timeout(struct k_work * work)331 static void publish_timeout(struct k_work *work)
332 {
333 	int rc;
334 
335 	if (!mqtt_connected) {
336 		return;
337 	}
338 
339 	rc = publish(&client_ctx, MQTT_QOS_1_AT_LEAST_ONCE);
340 	if (rc) {
341 		LOG_ERR("mqtt_publish ERROR");
342 		goto end;
343 	}
344 
345 	LOG_DBG("mqtt_publish OK");
346 end:
347 	k_work_reschedule(&pub_message, K_SECONDS(timeout_for_publish()));
348 }
349 
try_to_connect(struct mqtt_client * client)350 static int try_to_connect(struct mqtt_client *client)
351 {
352 	uint8_t retries = 3U;
353 	int rc;
354 
355 	LOG_DBG("attempting to connect...");
356 
357 	while (retries--) {
358 		client_init(client);
359 
360 		rc = mqtt_connect(client);
361 		if (rc) {
362 			LOG_ERR("mqtt_connect failed %d", rc);
363 			continue;
364 		}
365 
366 		prepare_fds(client);
367 
368 		rc = wait(APP_SLEEP_MSECS);
369 		if (rc < 0) {
370 			mqtt_abort(client);
371 			return rc;
372 		}
373 
374 		mqtt_input(client);
375 
376 		if (mqtt_connected) {
377 			subscribe(client);
378 			k_work_reschedule(&pub_message,
379 					  K_SECONDS(timeout_for_publish()));
380 			return 0;
381 		}
382 
383 		mqtt_abort(client);
384 
385 		wait(10 * MSEC_PER_SEC);
386 	}
387 
388 	return -EINVAL;
389 }
390 
#if defined(CONFIG_DNS_RESOLVER)
/*
 * Resolve CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME to an IPv4 TCP address, making
 * up to three attempts. On success the result is left in haddr for
 * broker_init(). The "8883" service string only sets the port inside the
 * result; broker_init() copies just the address and uses SERVER_PORT.
 *
 * Returns 0 on success or the last getaddrinfo() error code.
 */
static int get_mqtt_broker_addrinfo(void)
{
	int retries = 3;
	int rc = -EINVAL;

	/* Called again on every reconnect: release the previous lookup result
	 * so repeated resolutions do not leak it.
	 */
	if (haddr != NULL) {
		freeaddrinfo(haddr);
		haddr = NULL;
	}

	hints.ai_family = AF_INET;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = 0;

	while (retries--) {
		rc = getaddrinfo(CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME, "8883",
				 &hints, &haddr);
		if (rc == 0) {
			LOG_INF("DNS resolved for %s:%d",
				CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME,
				CONFIG_SAMPLE_CLOUD_AZURE_SERVER_PORT);

			return 0;
		}

		LOG_ERR("DNS not resolved for %s:%d, retrying",
			CONFIG_SAMPLE_CLOUD_AZURE_HOSTNAME,
			CONFIG_SAMPLE_CLOUD_AZURE_SERVER_PORT);
	}

	return rc;
}
#endif
420 
connect_to_cloud_and_publish(void)421 static void connect_to_cloud_and_publish(void)
422 {
423 	int rc = -EINVAL;
424 
425 #if defined(CONFIG_NET_DHCPV4)
426 	while (true) {
427 		k_sem_take(&mqtt_start, K_FOREVER);
428 #endif
429 #if defined(CONFIG_DNS_RESOLVER)
430 		rc = get_mqtt_broker_addrinfo();
431 		if (rc) {
432 			return;
433 		}
434 #endif
435 		rc = try_to_connect(&client_ctx);
436 		if (rc) {
437 			return;
438 		}
439 
440 		poll_mqtt();
441 #if defined(CONFIG_NET_DHCPV4)
442 	}
443 #endif
444 }
445 
/* After the interface goes down and comes back up, DHCPv4 tries to renew the
 * address. A successful renewal generates no event, so the DHCP state is
 * polled here instead. If DHCPv4 exceeds its maximum number of attempts, it
 * deletes the interface address and starts a new request, and an
 * IPV4_ADDR_ADD event would follow. Note that this file registers only for
 * L4 connected/disconnected events.
 */
#if defined(CONFIG_NET_DHCPV4)
/*
 * Work handler for check_network_conn. Does nothing while MQTT is connected.
 * Otherwise it gives mqtt_start once the default interface holds a bound
 * DHCPv4 lease, and re-checks every 3 seconds until then.
 */
static void check_network_connection(struct k_work *work)
{
	struct net_if *iface;

	if (mqtt_connected) {
		return;
	}

	iface = net_if_get_default();
	if (iface != NULL && iface->config.dhcpv4.state == NET_DHCPV4_BOUND) {
		k_sem_give(&mqtt_start);
		return;
	}

	if (iface != NULL) {
		LOG_INF("waiting for DHCP to acquire addr");
	}

	k_work_reschedule(&check_network_conn, K_SECONDS(3));
}
#endif
478 
#if defined(CONFIG_NET_DHCPV4)
/*
 * Tear down an active MQTT session: clear the connected flag, abort the
 * client and cancel the periodic publish work. No-op when not connected.
 */
static void abort_mqtt_connection(void)
{
	if (!mqtt_connected) {
		return;
	}

	mqtt_connected = false;
	mqtt_abort(&client_ctx);
	k_work_cancel_delayable(&pub_message);
}

/*
 * Network-management callback for L4 connectivity changes.
 *
 * On L4 connected, check_network_conn is scheduled in 3 s to wait for the
 * DHCP lease to be bound. On L4 disconnected, any MQTT session is aborted
 * and a pending network check is cancelled. Any other (or combined) event
 * value is ignored.
 */
static void l4_event_handler(struct net_mgmt_event_callback *cb,
			     uint32_t mgmt_event, struct net_if *iface)
{
	switch (mgmt_event) {
	case NET_EVENT_L4_CONNECTED:
		/* Wait for DHCP to be back in BOUND state */
		k_work_reschedule(&check_network_conn, K_SECONDS(3));
		break;

	case NET_EVENT_L4_DISCONNECTED:
		abort_mqtt_connection();
		k_work_cancel_delayable(&check_network_conn);
		break;

	default:
		break;
	}
}
#endif
511 
main(void)512 int main(void)
513 {
514 	int rc;
515 
516 	LOG_DBG("Waiting for network to setup...");
517 
518 	rc = tls_init();
519 	if (rc) {
520 		return 0;
521 	}
522 
523 	k_work_init_delayable(&pub_message, publish_timeout);
524 
525 #if defined(CONFIG_NET_DHCPV4)
526 	k_work_init_delayable(&check_network_conn, check_network_connection);
527 
528 	net_mgmt_init_event_callback(&l4_mgmt_cb, l4_event_handler,
529 				     L4_EVENT_MASK);
530 	net_mgmt_add_event_callback(&l4_mgmt_cb);
531 #endif
532 
533 	connect_to_cloud_and_publish();
534 	return 0;
535 }
536