1 /*
2  * Copyright (c) 2023 Lucas Dietrich <ld.adecy@gmail.com>
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include "creds/creds.h"
8 
9 #include <errno.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12 
13 #include <zephyr/net/socket.h>
14 #include <zephyr/net/dns_resolve.h>
15 #include <zephyr/net/mqtt.h>
16 #include <zephyr/net/tls_credentials.h>
17 #include <zephyr/data/json.h>
18 #include <zephyr/random/random.h>
19 #include <zephyr/logging/log.h>
20 
21 
22 #if defined(CONFIG_MBEDTLS_MEMORY_DEBUG)
23 #include <mbedtls/memory_buffer_alloc.h>
24 #endif
25 
LOG_MODULE_REGISTER(aws, LOG_LEVEL_DBG);

/* NOTE(review): SNTP_SERVER is not referenced anywhere else in this file. */
#define SNTP_SERVER "0.pool.ntp.org"

/* TCP port of the AWS IoT broker, used when resolving CONFIG_AWS_ENDPOINT. */
#define AWS_BROKER_PORT CONFIG_AWS_MQTT_PORT

/* Size in bytes of each of the MQTT client's RX and TX buffers. */
#define MQTT_BUFFER_SIZE 256u
/* Size in bytes of the application buffer holding outgoing JSON and incoming payloads. */
#define APP_BUFFER_SIZE	 4096u

/* Connection retry policy used by aws_client_try_connect() / backoff_get_next(). */
#define MAX_RETRIES	    10u    /* retry budget stored in backoff_context.max_retries */
#define BACKOFF_EXP_BASE_MS 1000u  /* first cap of the jittered exponential delay (ms) */
#define BACKOFF_EXP_MAX_MS  60000u /* ceiling of the exponential delay cap (ms) */
#define BACKOFF_CONST_MS    5000u  /* fixed delay without CONFIG_AWS_EXPONENTIAL_BACKOFF (ms) */
39 
/* Broker address, filled by resolve_broker_addr() from main() and referenced by client_ctx. */
static struct sockaddr_in aws_broker;

/* RX/TX buffers handed to the MQTT client in aws_client_setup(). */
static uint8_t rx_buffer[MQTT_BUFFER_SIZE];
static uint8_t tx_buffer[MQTT_BUFFER_SIZE];
static uint8_t buffer[APP_BUFFER_SIZE]; /* Shared between published and received messages */

/* The single MQTT client instance used by this application. */
static struct mqtt_client client_ctx;

/* MQTT client identifier; aws_client_setup() passes it without the trailing NUL. */
static const char mqtt_client_name[] = CONFIG_AWS_THING_NAME;

/* Number of PUBLISH messages received; its value is what publish() sends. */
static uint32_t messages_received_counter;
/* Deferred-action flags: raised in mqtt_event_cb(), consumed in aws_client_loop(). */
static bool do_publish;	  /* Trigger client to publish */
static bool do_subscribe; /* Trigger client to subscribe */

/* ALPN protocol offered by aws_client_setup() when connecting on port 443 without
 * WebSocket. NOTE(review): presumably required by AWS IoT to accept MQTT on port
 * 443 — confirm against the AWS IoT protocol documentation.
 */
#if (CONFIG_AWS_MQTT_PORT == 443 && !defined(CONFIG_MQTT_LIB_WEBSOCKET))
static const char * const alpn_list[] = {"x-amzn-mqtt-ca"};
#endif
57 
/* Security tags under which setup_credentials() registers the credentials.
 * The device certificate and its private key are registered under the same
 * tag (1); the AWS CA certificate uses tag 2.
 */
#define TLS_TAG_DEVICE_CERTIFICATE 1
#define TLS_TAG_DEVICE_PRIVATE_KEY 1
#define TLS_TAG_AWS_CA_CERTIFICATE 2

/* Tags handed to the TLS transport in aws_client_setup(). The private key is
 * covered by TLS_TAG_DEVICE_CERTIFICATE since both share tag 1.
 */
static const sec_tag_t sec_tls_tags[] = {
	TLS_TAG_DEVICE_CERTIFICATE,
	TLS_TAG_AWS_CA_CERTIFICATE,
};
66 
setup_credentials(void)67 static int setup_credentials(void)
68 {
69 	int ret;
70 
71 	ret = tls_credential_add(TLS_TAG_DEVICE_CERTIFICATE, TLS_CREDENTIAL_SERVER_CERTIFICATE,
72 				 public_cert, public_cert_len);
73 	if (ret < 0) {
74 		LOG_ERR("Failed to add device certificate: %d", ret);
75 		goto exit;
76 	}
77 
78 	ret = tls_credential_add(TLS_TAG_DEVICE_PRIVATE_KEY, TLS_CREDENTIAL_PRIVATE_KEY,
79 				 private_key, private_key_len);
80 	if (ret < 0) {
81 		LOG_ERR("Failed to add device private key: %d", ret);
82 		goto exit;
83 	}
84 
85 	ret = tls_credential_add(TLS_TAG_AWS_CA_CERTIFICATE, TLS_CREDENTIAL_CA_CERTIFICATE, ca_cert,
86 				 ca_cert_len);
87 	if (ret < 0) {
88 		LOG_ERR("Failed to add device private key: %d", ret);
89 		goto exit;
90 	}
91 
92 exit:
93 	return ret;
94 }
95 
subscribe_topic(void)96 static int subscribe_topic(void)
97 {
98 	int ret;
99 	struct mqtt_topic topics[] = {{
100 		.topic = {.utf8 = CONFIG_AWS_SUBSCRIBE_TOPIC,
101 			  .size = strlen(CONFIG_AWS_SUBSCRIBE_TOPIC)},
102 		.qos = CONFIG_AWS_QOS,
103 	}};
104 	const struct mqtt_subscription_list sub_list = {
105 		.list = topics,
106 		.list_count = ARRAY_SIZE(topics),
107 		.message_id = 1u,
108 	};
109 
110 	LOG_INF("Subscribing to %hu topic(s)", sub_list.list_count);
111 
112 	ret = mqtt_subscribe(&client_ctx, &sub_list);
113 	if (ret != 0) {
114 		LOG_ERR("Failed to subscribe to topics: %d", ret);
115 	}
116 
117 	return ret;
118 }
119 
publish_message(const char * topic,size_t topic_len,uint8_t * payload,size_t payload_len)120 static int publish_message(const char *topic, size_t topic_len, uint8_t *payload,
121 			   size_t payload_len)
122 {
123 	static uint32_t message_id = 1u;
124 
125 	int ret;
126 	struct mqtt_publish_param msg;
127 
128 	msg.retain_flag = 0u;
129 	msg.dup_flag = 0u;
130 	msg.message.topic.topic.utf8 = topic;
131 	msg.message.topic.topic.size = topic_len;
132 	msg.message.topic.qos = CONFIG_AWS_QOS;
133 	msg.message.payload.data = payload;
134 	msg.message.payload.len = payload_len;
135 	msg.message_id = message_id++;
136 
137 	ret = mqtt_publish(&client_ctx, &msg);
138 	if (ret != 0) {
139 		LOG_ERR("Failed to publish message: %d", ret);
140 	}
141 
142 	LOG_INF("PUBLISHED on topic \"%s\" [ id: %u qos: %u ], payload: %u B", topic,
143 		msg.message_id, msg.message.topic.qos, payload_len);
144 	LOG_HEXDUMP_DBG(payload, payload_len, "Published payload:");
145 
146 	return ret;
147 }
148 
handle_published_message(const struct mqtt_publish_param * pub)149 static ssize_t handle_published_message(const struct mqtt_publish_param *pub)
150 {
151 	int ret;
152 	size_t received = 0u;
153 	const size_t message_size = pub->message.payload.len;
154 	const bool discarded = message_size > APP_BUFFER_SIZE;
155 
156 	LOG_INF("RECEIVED on topic \"%s\" [ id: %u qos: %u ] payload: %u / %u B",
157 		(const char *)pub->message.topic.topic.utf8, pub->message_id,
158 		pub->message.topic.qos, message_size, APP_BUFFER_SIZE);
159 
160 	while (received < message_size) {
161 		uint8_t *p = discarded ? buffer : &buffer[received];
162 
163 		ret = mqtt_read_publish_payload_blocking(&client_ctx, p, APP_BUFFER_SIZE);
164 		if (ret < 0) {
165 			return ret;
166 		}
167 
168 		received += ret;
169 	}
170 
171 	if (!discarded) {
172 		LOG_HEXDUMP_DBG(buffer, MIN(message_size, 256u), "Received payload:");
173 	}
174 
175 	/* Send ACK */
176 	switch (pub->message.topic.qos) {
177 	case MQTT_QOS_1_AT_LEAST_ONCE: {
178 		struct mqtt_puback_param puback;
179 
180 		puback.message_id = pub->message_id;
181 		mqtt_publish_qos1_ack(&client_ctx, &puback);
182 	} break;
183 	case MQTT_QOS_2_EXACTLY_ONCE: /* unhandled (not supported by AWS) */
184 	case MQTT_QOS_0_AT_MOST_ONCE: /* nothing to do */
185 	default:
186 		break;
187 	}
188 
189 	return discarded ? -ENOMEM : received;
190 }
191 
/*
 * Return a printable name for an MQTT event type, used for debug logging.
 *
 * The lookup is positional: the name at index N is returned for event value N.
 * NOTE(review): this assumes enum mqtt_evt_type is ordered CONNACK (0) ...
 * PINGRESP (9) as listed below — verify against zephyr/net/mqtt.h.
 * Any value outside the table yields "<unknown>".
 */
const char *mqtt_evt_type_to_str(enum mqtt_evt_type type)
{
	static const char *const names[] = {
		"CONNACK",
		"DISCONNECT",
		"PUBLISH",
		"PUBACK",
		"PUBREC",
		"PUBREL",
		"PUBCOMP",
		"SUBACK",
		"UNSUBACK",
		"PINGRESP",
	};
	const size_t index = (size_t)type;

	if (index >= ARRAY_SIZE(names)) {
		return "<unknown>";
	}

	return names[index];
}
201 
mqtt_event_cb(struct mqtt_client * client,const struct mqtt_evt * evt)202 static void mqtt_event_cb(struct mqtt_client *client, const struct mqtt_evt *evt)
203 {
204 	LOG_DBG("MQTT event: %s [%u] result: %d", mqtt_evt_type_to_str(evt->type), evt->type,
205 		evt->result);
206 
207 	switch (evt->type) {
208 	case MQTT_EVT_CONNACK: {
209 		do_subscribe = true;
210 	} break;
211 
212 	case MQTT_EVT_PUBLISH: {
213 		const struct mqtt_publish_param *pub = &evt->param.publish;
214 
215 		handle_published_message(pub);
216 		messages_received_counter++;
217 #if !defined(CONFIG_AWS_TEST_SUITE_RECV_QOS1)
218 		do_publish = true;
219 #endif
220 	} break;
221 
222 	case MQTT_EVT_SUBACK: {
223 #if !defined(CONFIG_AWS_TEST_SUITE_RECV_QOS1)
224 		do_publish = true;
225 #endif
226 	} break;
227 
228 	case MQTT_EVT_PUBACK:
229 	case MQTT_EVT_DISCONNECT:
230 	case MQTT_EVT_PUBREC:
231 	case MQTT_EVT_PUBREL:
232 	case MQTT_EVT_PUBCOMP:
233 	case MQTT_EVT_PINGRESP:
234 	case MQTT_EVT_UNSUBACK:
235 	default:
236 		break;
237 	}
238 }
239 
aws_client_setup(void)240 static void aws_client_setup(void)
241 {
242 	mqtt_client_init(&client_ctx);
243 
244 	client_ctx.broker = &aws_broker;
245 	client_ctx.evt_cb = mqtt_event_cb;
246 
247 	client_ctx.client_id.utf8 = (uint8_t *)mqtt_client_name;
248 	client_ctx.client_id.size = sizeof(mqtt_client_name) - 1;
249 	client_ctx.password = NULL;
250 	client_ctx.user_name = NULL;
251 
252 	client_ctx.keepalive = CONFIG_MQTT_KEEPALIVE;
253 
254 	client_ctx.protocol_version = MQTT_VERSION_3_1_1;
255 
256 	client_ctx.rx_buf = rx_buffer;
257 	client_ctx.rx_buf_size = MQTT_BUFFER_SIZE;
258 	client_ctx.tx_buf = tx_buffer;
259 	client_ctx.tx_buf_size = MQTT_BUFFER_SIZE;
260 
261 	/* setup TLS */
262 	client_ctx.transport.type = MQTT_TRANSPORT_SECURE;
263 	struct mqtt_sec_config *const tls_config = &client_ctx.transport.tls.config;
264 
265 	tls_config->peer_verify = TLS_PEER_VERIFY_REQUIRED;
266 	tls_config->cipher_list = NULL;
267 	tls_config->sec_tag_list = sec_tls_tags;
268 	tls_config->sec_tag_count = ARRAY_SIZE(sec_tls_tags);
269 	tls_config->hostname = CONFIG_AWS_ENDPOINT;
270 	tls_config->cert_nocopy = TLS_CERT_NOCOPY_NONE;
271 #if (CONFIG_AWS_MQTT_PORT == 443 && !defined(CONFIG_MQTT_LIB_WEBSOCKET))
272 	tls_config->alpn_protocol_name_list = alpn_list;
273 	tls_config->alpn_protocol_name_count = ARRAY_SIZE(alpn_list);
274 #endif
275 }
276 
/* Retry bookkeeping for aws_client_try_connect(); reset by backoff_context_init()
 * and advanced by backoff_get_next().
 */
struct backoff_context {
	uint16_t retries_count; /* retries consumed so far */
	uint16_t max_retries;	/* retry budget (initialised to MAX_RETRIES) */

#if defined(CONFIG_AWS_EXPONENTIAL_BACKOFF)
	uint32_t attempt_max_backoff; /* ms, upper bound of the next random delay */
	uint32_t max_backoff;	      /* ms, ceiling for attempt_max_backoff */
#endif
};
286 
backoff_context_init(struct backoff_context * bo)287 static void backoff_context_init(struct backoff_context *bo)
288 {
289 	__ASSERT_NO_MSG(bo != NULL);
290 
291 	bo->retries_count = 0u;
292 	bo->max_retries = MAX_RETRIES;
293 
294 #if defined(CONFIG_AWS_EXPONENTIAL_BACKOFF)
295 	bo->attempt_max_backoff = BACKOFF_EXP_BASE_MS;
296 	bo->max_backoff = BACKOFF_EXP_MAX_MS;
297 #endif
298 }
299 
300 /* https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ */
backoff_get_next(struct backoff_context * bo,uint32_t * next_backoff_ms)301 static void backoff_get_next(struct backoff_context *bo, uint32_t *next_backoff_ms)
302 {
303 	__ASSERT_NO_MSG(bo != NULL);
304 	__ASSERT_NO_MSG(next_backoff_ms != NULL);
305 
306 #if defined(CONFIG_AWS_EXPONENTIAL_BACKOFF)
307 	if (bo->retries_count <= bo->max_retries) {
308 		*next_backoff_ms = sys_rand32_get() % (bo->attempt_max_backoff + 1u);
309 
310 		/* Calculate max backoff for the next attempt (~ 2**attempt) */
311 		bo->attempt_max_backoff = MIN(bo->attempt_max_backoff * 2u, bo->max_backoff);
312 		bo->retries_count++;
313 	}
314 #else
315 	*next_backoff_ms = BACKOFF_CONST_MS;
316 #endif
317 }
318 
aws_client_try_connect(void)319 static int aws_client_try_connect(void)
320 {
321 	int ret;
322 	uint32_t backoff_ms;
323 	struct backoff_context bo;
324 
325 	backoff_context_init(&bo);
326 
327 	while (bo.retries_count <= bo.max_retries) {
328 		ret = mqtt_connect(&client_ctx);
329 		if (ret == 0) {
330 			goto exit;
331 		}
332 
333 		backoff_get_next(&bo, &backoff_ms);
334 
335 		LOG_ERR("Failed to connect: %d backoff delay: %u ms", ret, backoff_ms);
336 		k_msleep(backoff_ms);
337 	}
338 
339 exit:
340 	return ret;
341 }
342 
/* Content of the message published by publish(). */
struct publish_payload {
	uint32_t counter; /* filled from messages_received_counter */
};

/* Zephyr JSON descriptor for struct publish_payload, used by publish().
 * JSON_OBJ_DESCR_PRIM derives the JSON key from the struct field name, so
 * renaming the field changes the published payload.
 */
static const struct json_obj_descr json_descr[] = {
	JSON_OBJ_DESCR_PRIM(struct publish_payload, counter, JSON_TOK_NUMBER),
};
350 
publish(void)351 static int publish(void)
352 {
353 	struct publish_payload pl = {.counter = messages_received_counter};
354 
355 	json_obj_encode_buf(json_descr, ARRAY_SIZE(json_descr), &pl, buffer, sizeof(buffer));
356 
357 	return publish_message(CONFIG_AWS_PUBLISH_TOPIC, strlen(CONFIG_AWS_PUBLISH_TOPIC), buffer,
358 			       strlen(buffer));
359 }
360 
aws_client_loop(void)361 void aws_client_loop(void)
362 {
363 	int rc;
364 	int timeout;
365 	struct pollfd fds;
366 
367 	aws_client_setup();
368 
369 	rc = aws_client_try_connect();
370 	if (rc != 0) {
371 		goto cleanup;
372 	}
373 
374 	fds.fd = client_ctx.transport.tcp.sock;
375 	fds.events = POLLIN;
376 
377 	for (;;) {
378 		timeout = mqtt_keepalive_time_left(&client_ctx);
379 		rc = poll(&fds, 1u, timeout);
380 		if (rc >= 0) {
381 			if (fds.revents & POLLIN) {
382 				rc = mqtt_input(&client_ctx);
383 				if (rc != 0) {
384 					LOG_ERR("Failed to read MQTT input: %d", rc);
385 					break;
386 				}
387 			}
388 
389 			if (fds.revents & (POLLHUP | POLLERR)) {
390 				LOG_ERR("Socket closed/error");
391 				break;
392 			}
393 
394 			rc = mqtt_live(&client_ctx);
395 			if ((rc != 0) && (rc != -EAGAIN)) {
396 				LOG_ERR("Failed to live MQTT: %d", rc);
397 				break;
398 			}
399 		} else {
400 			LOG_ERR("poll failed: %d", rc);
401 			break;
402 		}
403 
404 		if (do_publish) {
405 			do_publish = false;
406 			publish();
407 		}
408 
409 		if (do_subscribe) {
410 			do_subscribe = false;
411 			subscribe_topic();
412 		}
413 	}
414 
415 cleanup:
416 	mqtt_disconnect(&client_ctx);
417 
418 	close(fds.fd);
419 	fds.fd = -1;
420 }
421 
/*
 * Resolve CONFIG_AWS_ENDPOINT on port AWS_BROKER_PORT to an IPv4 TCP address,
 * and store the first result in @p broker.
 *
 * Returns 0 on success. Otherwise returns the non-zero getaddrinfo() error
 * code, and @p broker is left untouched.
 */
static int resolve_broker_addr(struct sockaddr_in *broker)
{
	int ret;
	struct addrinfo *ai = NULL;
	char addr_str[INET_ADDRSTRLEN];
	char port_string[6] = {0}; /* up to "65535" plus NUL */

	const struct addrinfo hints = {
		.ai_family = AF_INET,
		.ai_socktype = SOCK_STREAM,
		.ai_protocol = 0,
	};

	snprintf(port_string, sizeof(port_string), "%d", AWS_BROKER_PORT);

	ret = getaddrinfo(CONFIG_AWS_ENDPOINT, port_string, &hints, &ai);
	if (ret != 0) {
		/* ai was not allocated: freeaddrinfo(NULL) is not portable. */
		LOG_ERR("failed to resolve hostname err = %d (errno = %d)", ret, errno);
		return ret;
	}

	/* Bound the copy by the destination type (sockaddr_in), not sockaddr_storage. */
	memcpy(broker, ai->ai_addr, MIN(ai->ai_addrlen, sizeof(*broker)));

	inet_ntop(AF_INET, &broker->sin_addr, addr_str, sizeof(addr_str));
	LOG_INF("Resolved: %s:%u", addr_str, ntohs(broker->sin_port));

	freeaddrinfo(ai);

	return 0;
}
451 
main(void)452 int main(void)
453 {
454 	setup_credentials();
455 
456 	for (;;) {
457 		resolve_broker_addr(&aws_broker);
458 
459 		aws_client_loop();
460 
461 #if defined(CONFIG_MBEDTLS_MEMORY_DEBUG)
462 		size_t cur_used, cur_blocks, max_used, max_blocks;
463 
464 		mbedtls_memory_buffer_alloc_cur_get(&cur_used, &cur_blocks);
465 		mbedtls_memory_buffer_alloc_max_get(&max_used, &max_blocks);
466 		LOG_INF("mbedTLS heap usage: MAX %u/%u (%u) CUR %u (%u)", max_used,
467 			CONFIG_MBEDTLS_HEAP_SIZE, max_blocks, cur_used, cur_blocks);
468 #endif
469 
470 		k_sleep(K_SECONDS(1));
471 	}
472 
473 	return 0;
474 }
475