1 /*
2 * Copyright (c) 2023 Lucas Dietrich <ld.adecy@gmail.com>
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include "creds/creds.h"
8
9 #include <errno.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12
13 #include <zephyr/net/socket.h>
14 #include <zephyr/net/dns_resolve.h>
15 #include <zephyr/net/mqtt.h>
16 #include <zephyr/net/tls_credentials.h>
17 #include <zephyr/data/json.h>
18 #include <zephyr/random/random.h>
19 #include <zephyr/logging/log.h>
20
21
22 #if defined(CONFIG_MBEDTLS_MEMORY_DEBUG)
23 #include <mbedtls/memory_buffer_alloc.h>
24 #endif
25
/* Register this file's log module under the name "aws" at debug level. */
LOG_MODULE_REGISTER(aws, LOG_LEVEL_DBG);

/* NOTE(review): SNTP_SERVER is not referenced by any code visible in this file. */
#define SNTP_SERVER "0.pool.ntp.org"

/* TCP port of the broker, taken from Kconfig; used when resolving the endpoint. */
#define AWS_BROKER_PORT CONFIG_AWS_MQTT_PORT

/* Size of each of the RX/TX buffers handed to the MQTT client. */
#define MQTT_BUFFER_SIZE 256u
/* Size of the application payload buffer (see `buffer` below). */
#define APP_BUFFER_SIZE 4096u

/* Connection retry policy (consumed by the backoff helpers further down). */
#define MAX_RETRIES 10u
#define BACKOFF_EXP_BASE_MS 1000u
#define BACKOFF_EXP_MAX_MS 60000u
#define BACKOFF_CONST_MS 5000u

/* Broker address, filled in by resolve_broker_addr() and used by the MQTT client. */
static struct sockaddr_in aws_broker;

static uint8_t rx_buffer[MQTT_BUFFER_SIZE];
static uint8_t tx_buffer[MQTT_BUFFER_SIZE];
static uint8_t buffer[APP_BUFFER_SIZE]; /* Shared between published and received messages */

/* The single MQTT client instance used by every function in this file. */
static struct mqtt_client client_ctx;

/* MQTT client identifier; the thing name configured through Kconfig. */
static const char mqtt_client_name[] = CONFIG_AWS_THING_NAME;

/* Incremented on every MQTT_EVT_PUBLISH event; sent back as the JSON "counter". */
static uint32_t messages_received_counter;
static bool do_publish; /* Trigger client to publish */
static bool do_subscribe; /* Trigger client to subscribe */

/* ALPN protocol list, only compiled in for port 443 without the websocket transport. */
#if (CONFIG_AWS_MQTT_PORT == 443 && !defined(CONFIG_MQTT_LIB_WEBSOCKET))
static const char * const alpn_list[] = {"x-amzn-mqtt-ca"};
#endif

/*
 * TLS credential tags. The device certificate and its private key are
 * registered under the same tag value; the CA certificate has its own.
 */
#define TLS_TAG_DEVICE_CERTIFICATE 1
#define TLS_TAG_DEVICE_PRIVATE_KEY 1
#define TLS_TAG_AWS_CA_CERTIFICATE 2

/* Tag list handed to the MQTT TLS configuration in aws_client_setup(). */
static const sec_tag_t sec_tls_tags[] = {
	TLS_TAG_DEVICE_CERTIFICATE,
	TLS_TAG_AWS_CA_CERTIFICATE,
};
66
setup_credentials(void)67 static int setup_credentials(void)
68 {
69 int ret;
70
71 ret = tls_credential_add(TLS_TAG_DEVICE_CERTIFICATE, TLS_CREDENTIAL_SERVER_CERTIFICATE,
72 public_cert, public_cert_len);
73 if (ret < 0) {
74 LOG_ERR("Failed to add device certificate: %d", ret);
75 goto exit;
76 }
77
78 ret = tls_credential_add(TLS_TAG_DEVICE_PRIVATE_KEY, TLS_CREDENTIAL_PRIVATE_KEY,
79 private_key, private_key_len);
80 if (ret < 0) {
81 LOG_ERR("Failed to add device private key: %d", ret);
82 goto exit;
83 }
84
85 ret = tls_credential_add(TLS_TAG_AWS_CA_CERTIFICATE, TLS_CREDENTIAL_CA_CERTIFICATE, ca_cert,
86 ca_cert_len);
87 if (ret < 0) {
88 LOG_ERR("Failed to add device private key: %d", ret);
89 goto exit;
90 }
91
92 exit:
93 return ret;
94 }
95
subscribe_topic(void)96 static int subscribe_topic(void)
97 {
98 int ret;
99 struct mqtt_topic topics[] = {{
100 .topic = {.utf8 = CONFIG_AWS_SUBSCRIBE_TOPIC,
101 .size = strlen(CONFIG_AWS_SUBSCRIBE_TOPIC)},
102 .qos = CONFIG_AWS_QOS,
103 }};
104 const struct mqtt_subscription_list sub_list = {
105 .list = topics,
106 .list_count = ARRAY_SIZE(topics),
107 .message_id = 1u,
108 };
109
110 LOG_INF("Subscribing to %hu topic(s)", sub_list.list_count);
111
112 ret = mqtt_subscribe(&client_ctx, &sub_list);
113 if (ret != 0) {
114 LOG_ERR("Failed to subscribe to topics: %d", ret);
115 }
116
117 return ret;
118 }
119
publish_message(const char * topic,size_t topic_len,uint8_t * payload,size_t payload_len)120 static int publish_message(const char *topic, size_t topic_len, uint8_t *payload,
121 size_t payload_len)
122 {
123 static uint32_t message_id = 1u;
124
125 int ret;
126 struct mqtt_publish_param msg;
127
128 msg.retain_flag = 0u;
129 msg.dup_flag = 0u;
130 msg.message.topic.topic.utf8 = topic;
131 msg.message.topic.topic.size = topic_len;
132 msg.message.topic.qos = CONFIG_AWS_QOS;
133 msg.message.payload.data = payload;
134 msg.message.payload.len = payload_len;
135 msg.message_id = message_id++;
136
137 ret = mqtt_publish(&client_ctx, &msg);
138 if (ret != 0) {
139 LOG_ERR("Failed to publish message: %d", ret);
140 }
141
142 LOG_INF("PUBLISHED on topic \"%s\" [ id: %u qos: %u ], payload: %u B", topic,
143 msg.message_id, msg.message.topic.qos, payload_len);
144 LOG_HEXDUMP_DBG(payload, payload_len, "Published payload:");
145
146 return ret;
147 }
148
handle_published_message(const struct mqtt_publish_param * pub)149 static ssize_t handle_published_message(const struct mqtt_publish_param *pub)
150 {
151 int ret;
152 size_t received = 0u;
153 const size_t message_size = pub->message.payload.len;
154 const bool discarded = message_size > APP_BUFFER_SIZE;
155
156 LOG_INF("RECEIVED on topic \"%s\" [ id: %u qos: %u ] payload: %u / %u B",
157 (const char *)pub->message.topic.topic.utf8, pub->message_id,
158 pub->message.topic.qos, message_size, APP_BUFFER_SIZE);
159
160 while (received < message_size) {
161 uint8_t *p = discarded ? buffer : &buffer[received];
162
163 ret = mqtt_read_publish_payload_blocking(&client_ctx, p, APP_BUFFER_SIZE);
164 if (ret < 0) {
165 return ret;
166 }
167
168 received += ret;
169 }
170
171 if (!discarded) {
172 LOG_HEXDUMP_DBG(buffer, MIN(message_size, 256u), "Received payload:");
173 }
174
175 /* Send ACK */
176 switch (pub->message.topic.qos) {
177 case MQTT_QOS_1_AT_LEAST_ONCE: {
178 struct mqtt_puback_param puback;
179
180 puback.message_id = pub->message_id;
181 mqtt_publish_qos1_ack(&client_ctx, &puback);
182 } break;
183 case MQTT_QOS_2_EXACTLY_ONCE: /* unhandled (not supported by AWS) */
184 case MQTT_QOS_0_AT_MOST_ONCE: /* nothing to do */
185 default:
186 break;
187 }
188
189 return discarded ? -ENOMEM : received;
190 }
191
/**
 * @brief Map an MQTT event type to a short printable name.
 *
 * The lookup table is indexed by the numeric value of @p type; the order of
 * its entries is presumably that of enum mqtt_evt_type - verify when the
 * enum changes.
 *
 * @param type Event type to describe.
 *
 * @return A constant string, "<unknown>" for values beyond the table.
 */
const char *mqtt_evt_type_to_str(enum mqtt_evt_type type)
{
	static const char *const names[] = {
		"CONNACK",
		"DISCONNECT",
		"PUBLISH",
		"PUBACK",
		"PUBREC",
		"PUBREL",
		"PUBCOMP",
		"SUBACK",
		"UNSUBACK",
		"PINGRESP",
	};

	if (type < ARRAY_SIZE(names)) {
		return names[type];
	}

	return "<unknown>";
}
201
mqtt_event_cb(struct mqtt_client * client,const struct mqtt_evt * evt)202 static void mqtt_event_cb(struct mqtt_client *client, const struct mqtt_evt *evt)
203 {
204 LOG_DBG("MQTT event: %s [%u] result: %d", mqtt_evt_type_to_str(evt->type), evt->type,
205 evt->result);
206
207 switch (evt->type) {
208 case MQTT_EVT_CONNACK: {
209 do_subscribe = true;
210 } break;
211
212 case MQTT_EVT_PUBLISH: {
213 const struct mqtt_publish_param *pub = &evt->param.publish;
214
215 handle_published_message(pub);
216 messages_received_counter++;
217 #if !defined(CONFIG_AWS_TEST_SUITE_RECV_QOS1)
218 do_publish = true;
219 #endif
220 } break;
221
222 case MQTT_EVT_SUBACK: {
223 #if !defined(CONFIG_AWS_TEST_SUITE_RECV_QOS1)
224 do_publish = true;
225 #endif
226 } break;
227
228 case MQTT_EVT_PUBACK:
229 case MQTT_EVT_DISCONNECT:
230 case MQTT_EVT_PUBREC:
231 case MQTT_EVT_PUBREL:
232 case MQTT_EVT_PUBCOMP:
233 case MQTT_EVT_PINGRESP:
234 case MQTT_EVT_UNSUBACK:
235 default:
236 break;
237 }
238 }
239
aws_client_setup(void)240 static void aws_client_setup(void)
241 {
242 mqtt_client_init(&client_ctx);
243
244 client_ctx.broker = &aws_broker;
245 client_ctx.evt_cb = mqtt_event_cb;
246
247 client_ctx.client_id.utf8 = (uint8_t *)mqtt_client_name;
248 client_ctx.client_id.size = sizeof(mqtt_client_name) - 1;
249 client_ctx.password = NULL;
250 client_ctx.user_name = NULL;
251
252 client_ctx.keepalive = CONFIG_MQTT_KEEPALIVE;
253
254 client_ctx.protocol_version = MQTT_VERSION_3_1_1;
255
256 client_ctx.rx_buf = rx_buffer;
257 client_ctx.rx_buf_size = MQTT_BUFFER_SIZE;
258 client_ctx.tx_buf = tx_buffer;
259 client_ctx.tx_buf_size = MQTT_BUFFER_SIZE;
260
261 /* setup TLS */
262 client_ctx.transport.type = MQTT_TRANSPORT_SECURE;
263 struct mqtt_sec_config *const tls_config = &client_ctx.transport.tls.config;
264
265 tls_config->peer_verify = TLS_PEER_VERIFY_REQUIRED;
266 tls_config->cipher_list = NULL;
267 tls_config->sec_tag_list = sec_tls_tags;
268 tls_config->sec_tag_count = ARRAY_SIZE(sec_tls_tags);
269 tls_config->hostname = CONFIG_AWS_ENDPOINT;
270 tls_config->cert_nocopy = TLS_CERT_NOCOPY_NONE;
271 #if (CONFIG_AWS_MQTT_PORT == 443 && !defined(CONFIG_MQTT_LIB_WEBSOCKET))
272 tls_config->alpn_protocol_name_list = alpn_list;
273 tls_config->alpn_protocol_name_count = ARRAY_SIZE(alpn_list);
274 #endif
275 }
276
/**
 * @brief Bookkeeping for connection retries.
 *
 * The two backoff window fields only exist when
 * CONFIG_AWS_EXPONENTIAL_BACKOFF is enabled.
 */
struct backoff_context {
	/** Number of backoff delays handed out so far. */
	uint16_t retries_count;
	/** Retry budget; callers compare retries_count against it. */
	uint16_t max_retries;

#if defined(CONFIG_AWS_EXPONENTIAL_BACKOFF)
	/** Upper bound of the next random delay, in milliseconds. */
	uint32_t attempt_max_backoff;
	/** Cap applied to attempt_max_backoff, in milliseconds. */
	uint32_t max_backoff;
#endif
};
286
backoff_context_init(struct backoff_context * bo)287 static void backoff_context_init(struct backoff_context *bo)
288 {
289 __ASSERT_NO_MSG(bo != NULL);
290
291 bo->retries_count = 0u;
292 bo->max_retries = MAX_RETRIES;
293
294 #if defined(CONFIG_AWS_EXPONENTIAL_BACKOFF)
295 bo->attempt_max_backoff = BACKOFF_EXP_BASE_MS;
296 bo->max_backoff = BACKOFF_EXP_MAX_MS;
297 #endif
298 }
299
300 /* https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ */
backoff_get_next(struct backoff_context * bo,uint32_t * next_backoff_ms)301 static void backoff_get_next(struct backoff_context *bo, uint32_t *next_backoff_ms)
302 {
303 __ASSERT_NO_MSG(bo != NULL);
304 __ASSERT_NO_MSG(next_backoff_ms != NULL);
305
306 #if defined(CONFIG_AWS_EXPONENTIAL_BACKOFF)
307 if (bo->retries_count <= bo->max_retries) {
308 *next_backoff_ms = sys_rand32_get() % (bo->attempt_max_backoff + 1u);
309
310 /* Calculate max backoff for the next attempt (~ 2**attempt) */
311 bo->attempt_max_backoff = MIN(bo->attempt_max_backoff * 2u, bo->max_backoff);
312 bo->retries_count++;
313 }
314 #else
315 *next_backoff_ms = BACKOFF_CONST_MS;
316 #endif
317 }
318
aws_client_try_connect(void)319 static int aws_client_try_connect(void)
320 {
321 int ret;
322 uint32_t backoff_ms;
323 struct backoff_context bo;
324
325 backoff_context_init(&bo);
326
327 while (bo.retries_count <= bo.max_retries) {
328 ret = mqtt_connect(&client_ctx);
329 if (ret == 0) {
330 goto exit;
331 }
332
333 backoff_get_next(&bo, &backoff_ms);
334
335 LOG_ERR("Failed to connect: %d backoff delay: %u ms", ret, backoff_ms);
336 k_msleep(backoff_ms);
337 }
338
339 exit:
340 return ret;
341 }
342
343 struct publish_payload {
344 uint32_t counter;
345 };
346
347 static const struct json_obj_descr json_descr[] = {
348 JSON_OBJ_DESCR_PRIM(struct publish_payload, counter, JSON_TOK_NUMBER),
349 };
350
publish(void)351 static int publish(void)
352 {
353 struct publish_payload pl = {.counter = messages_received_counter};
354
355 json_obj_encode_buf(json_descr, ARRAY_SIZE(json_descr), &pl, buffer, sizeof(buffer));
356
357 return publish_message(CONFIG_AWS_PUBLISH_TOPIC, strlen(CONFIG_AWS_PUBLISH_TOPIC), buffer,
358 strlen(buffer));
359 }
360
aws_client_loop(void)361 void aws_client_loop(void)
362 {
363 int rc;
364 int timeout;
365 struct pollfd fds;
366
367 aws_client_setup();
368
369 rc = aws_client_try_connect();
370 if (rc != 0) {
371 goto cleanup;
372 }
373
374 fds.fd = client_ctx.transport.tcp.sock;
375 fds.events = POLLIN;
376
377 for (;;) {
378 timeout = mqtt_keepalive_time_left(&client_ctx);
379 rc = poll(&fds, 1u, timeout);
380 if (rc >= 0) {
381 if (fds.revents & POLLIN) {
382 rc = mqtt_input(&client_ctx);
383 if (rc != 0) {
384 LOG_ERR("Failed to read MQTT input: %d", rc);
385 break;
386 }
387 }
388
389 if (fds.revents & (POLLHUP | POLLERR)) {
390 LOG_ERR("Socket closed/error");
391 break;
392 }
393
394 rc = mqtt_live(&client_ctx);
395 if ((rc != 0) && (rc != -EAGAIN)) {
396 LOG_ERR("Failed to live MQTT: %d", rc);
397 break;
398 }
399 } else {
400 LOG_ERR("poll failed: %d", rc);
401 break;
402 }
403
404 if (do_publish) {
405 do_publish = false;
406 publish();
407 }
408
409 if (do_subscribe) {
410 do_subscribe = false;
411 subscribe_topic();
412 }
413 }
414
415 cleanup:
416 mqtt_disconnect(&client_ctx);
417
418 close(fds.fd);
419 fds.fd = -1;
420 }
421
/**
 * @brief Resolve CONFIG_AWS_ENDPOINT to an IPv4 socket address.
 *
 * Looks the endpoint up with getaddrinfo() for AF_INET / SOCK_STREAM on
 * port AWS_BROKER_PORT and copies the first result into @p broker.
 *
 * @param broker Destination for the resolved address.
 *
 * @return 0 on success, otherwise the error returned by getaddrinfo().
 */
static int resolve_broker_addr(struct sockaddr_in *broker)
{
	int ret;
	struct addrinfo *ai = NULL;

	const struct addrinfo hints = {
		.ai_family = AF_INET,
		.ai_socktype = SOCK_STREAM,
		.ai_protocol = 0,
	};
	char port_string[6] = {0};

	snprintf(port_string, sizeof(port_string), "%d", AWS_BROKER_PORT);
	ret = getaddrinfo(CONFIG_AWS_ENDPOINT, port_string, &hints, &ai);
	if (ret == 0) {
		char addr_str[INET_ADDRSTRLEN];

		/* Bound the copy by the destination, not by sockaddr_storage. */
		memcpy(broker, ai->ai_addr, MIN(ai->ai_addrlen, sizeof(*broker)));

		inet_ntop(AF_INET, &broker->sin_addr, addr_str, sizeof(addr_str));
		/* sin_port is in network byte order; convert for display. */
		LOG_INF("Resolved: %s:%u", addr_str, ntohs(broker->sin_port));
	} else {
		LOG_ERR("failed to resolve hostname err = %d (errno = %d)", ret, errno);
	}

	if (ai != NULL) {
		freeaddrinfo(ai);
	}

	return ret;
}
451
main(void)452 int main(void)
453 {
454 setup_credentials();
455
456 for (;;) {
457 resolve_broker_addr(&aws_broker);
458
459 aws_client_loop();
460
461 #if defined(CONFIG_MBEDTLS_MEMORY_DEBUG)
462 size_t cur_used, cur_blocks, max_used, max_blocks;
463
464 mbedtls_memory_buffer_alloc_cur_get(&cur_used, &cur_blocks);
465 mbedtls_memory_buffer_alloc_max_get(&max_used, &max_blocks);
466 LOG_INF("mbedTLS heap usage: MAX %u/%u (%u) CUR %u (%u)", max_used,
467 CONFIG_MBEDTLS_HEAP_SIZE, max_blocks, cur_used, cur_blocks);
468 #endif
469
470 k_sleep(K_SECONDS(1));
471 }
472
473 return 0;
474 }
475