1 /*
2 * Copyright (c) 2022 René Beckmann
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include <zephyr/net/mqtt_sn.h>
8 #include <zephyr/sys/util.h> /* for ARRAY_SIZE */
9 #include <zephyr/tc_util.h>
10 #include <zephyr/ztest.h>
11
12 #include <mqtt_sn_msg.h>
13
14 #include <zephyr/logging/log.h>
15 LOG_MODULE_REGISTER(test);
16
/* Client ID used by the client under test. */
static const struct mqtt_sn_data client_id = MQTT_SN_DATA_STRING_LITERAL("zephyr");
/* Used as the source address of a SEARCHGW sent by another client (respond_searchgw test). */
static const struct mqtt_sn_data client2_id = MQTT_SN_DATA_STRING_LITERAL("zephyr2");
/* ID and transport address of the single test gateway. */
static const uint8_t gw_id = 12;
static const struct mqtt_sn_data gw_addr = MQTT_SN_DATA_STRING_LITERAL("gw1");

/* TX/RX buffers handed to mqtt_sn_client_init() by every test. */
static uint8_t tx[255];
static uint8_t rx[255];

/*
 * Arguments captured from the most recent msg_sendto() call, plus how often it was called.
 * `ret` is the value msg_sendto() returns to the client.
 */
static struct msg_send_data {
	int called;
	size_t msg_sz;
	int ret;
	const void *dest_addr;
	size_t addrlen;
	struct mqtt_sn_client *client;
} msg_send_data;

/* Signals from the mocks so tests can wait for asynchronous activity. */
struct k_sem mqtt_sn_tx_sem; /* given by msg_sendto() */
struct k_sem mqtt_sn_rx_sem; /* given by tp_recvfrom() when it delivers a frame */
struct k_sem mqtt_sn_cb_sem; /* given by evt_cb() */
37
mqtt_sn_data_cmp(struct mqtt_sn_data data1,struct mqtt_sn_data data2)38 int mqtt_sn_data_cmp(struct mqtt_sn_data data1, struct mqtt_sn_data data2)
39 {
40 return data1.size == data2.size && strncmp(data1.data, data2.data, data1.size);
41 }
42
msg_sendto(struct mqtt_sn_client * client,void * buf,size_t sz,const void * dest_addr,size_t addrlen)43 static int msg_sendto(struct mqtt_sn_client *client, void *buf, size_t sz, const void *dest_addr,
44 size_t addrlen)
45 {
46 msg_send_data.called++;
47 msg_send_data.msg_sz = sz;
48 msg_send_data.client = client;
49 msg_send_data.dest_addr = dest_addr;
50 msg_send_data.addrlen = addrlen;
51
52 k_sem_give(&mqtt_sn_tx_sem);
53
54 return msg_send_data.ret;
55 }
56
assert_msg_send(int called,size_t msg_sz,const struct mqtt_sn_data * dest_addr)57 static void assert_msg_send(int called, size_t msg_sz, const struct mqtt_sn_data *dest_addr)
58 {
59 zassert_equal(msg_send_data.called, called, "msg_send called %d times instead of %d",
60 msg_send_data.called, called);
61 zassert_equal(msg_send_data.msg_sz, msg_sz, "msg_sz is %zu instead of %zu",
62 msg_send_data.msg_sz, msg_sz);
63 if (dest_addr != NULL) {
64 zassert_equal(mqtt_sn_data_cmp(*dest_addr,
65 *((struct mqtt_sn_data *)msg_send_data.dest_addr)),
66 0, "Addresses incorrect");
67 }
68
69 memset(&msg_send_data, 0, sizeof(msg_send_data));
70 }
71
72 static struct {
73 struct mqtt_sn_evt last_evt;
74 int called;
75 } evt_cb_data;
76
evt_cb(struct mqtt_sn_client * client,const struct mqtt_sn_evt * evt)77 static void evt_cb(struct mqtt_sn_client *client, const struct mqtt_sn_evt *evt)
78 {
79 memcpy(&evt_cb_data.last_evt, evt, sizeof(*evt));
80 evt_cb_data.called++;
81
82 k_sem_give(&mqtt_sn_cb_sem);
83 }
84
85 static bool tp_initialized;
86 static struct mqtt_sn_transport transport;
87
tp_init(struct mqtt_sn_transport * tp)88 static int tp_init(struct mqtt_sn_transport *tp)
89 {
90 tp_initialized = true;
91 return 0;
92 }
93
94 static struct {
95 void *data;
96 ssize_t sz;
97 const void *src_addr;
98 size_t addrlen;
99 } recvfrom_data;
100
tp_recvfrom(struct mqtt_sn_client * client,void * buffer,size_t length,void * src_addr,size_t * addrlen)101 static ssize_t tp_recvfrom(struct mqtt_sn_client *client, void *buffer, size_t length,
102 void *src_addr, size_t *addrlen)
103 {
104 if (recvfrom_data.data && recvfrom_data.sz > 0 && length >= recvfrom_data.sz) {
105 memcpy(buffer, recvfrom_data.data, recvfrom_data.sz);
106 memcpy(src_addr, recvfrom_data.src_addr, recvfrom_data.addrlen);
107 *addrlen = recvfrom_data.addrlen;
108
109 k_sem_give(&mqtt_sn_rx_sem);
110 }
111
112 return recvfrom_data.sz;
113 }
114
tp_poll(struct mqtt_sn_client * client)115 int tp_poll(struct mqtt_sn_client *client)
116 {
117 return recvfrom_data.sz;
118 }
119
120 static ZTEST_BMEM struct mqtt_sn_client mqtt_clients[8];
121 static ZTEST_BMEM struct mqtt_sn_client *mqtt_client;
122
setup(void * f)123 static void setup(void *f)
124 {
125 ARG_UNUSED(f);
126 static ZTEST_BMEM size_t i;
127
128 mqtt_client = &mqtt_clients[i++];
129
130 transport = (struct mqtt_sn_transport){
131 .init = tp_init, .sendto = msg_sendto, .recvfrom = tp_recvfrom, .poll = tp_poll};
132 tp_initialized = false;
133
134 memset(&evt_cb_data, 0, sizeof(evt_cb_data));
135 memset(&msg_send_data, 0, sizeof(msg_send_data));
136 memset(&recvfrom_data, 0, sizeof(recvfrom_data));
137 k_sem_init(&mqtt_sn_tx_sem, 0, 1);
138 k_sem_init(&mqtt_sn_rx_sem, 0, 1);
139 k_sem_init(&mqtt_sn_cb_sem, 0, 1);
140 }
141
input(struct mqtt_sn_client * client,void * buf,size_t sz,const struct mqtt_sn_data * src_addr)142 static int input(struct mqtt_sn_client *client, void *buf, size_t sz,
143 const struct mqtt_sn_data *src_addr)
144 {
145 recvfrom_data.data = buf;
146 recvfrom_data.sz = sz;
147 recvfrom_data.src_addr = src_addr->data;
148 recvfrom_data.addrlen = src_addr->size;
149
150 return mqtt_sn_input(client);
151 }
152
mqtt_sn_connect_no_will(struct mqtt_sn_client * client)153 static void mqtt_sn_connect_no_will(struct mqtt_sn_client *client)
154 {
155 /* connack with return code accepted */
156 static uint8_t connack[] = {3, 0x05, 0x00};
157 int err;
158
159 err = mqtt_sn_client_init(client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
160 sizeof(rx));
161 zassert_equal(err, 0, "unexpected error %d");
162 zassert_true(tp_initialized, "Transport not initialized");
163
164 err = mqtt_sn_add_gw(client, gw_id, gw_addr);
165 zassert_equal(err, 0, "unexpected error %d");
166 zassert_equal(evt_cb_data.called, 0, "Unexpected event");
167 zassert_false(sys_slist_is_empty(&client->gateway), "GW not saved.");
168
169 err = mqtt_sn_connect(client, false, false);
170 zassert_equal(err, 0, "unexpected error %d");
171 assert_msg_send(1, 12, &gw_addr);
172 zassert_equal(client->state, 0, "Wrong state");
173 zassert_equal(evt_cb_data.called, 0, "Unexpected event");
174
175 err = input(client, connack, sizeof(connack), &gw_addr);
176 zassert_equal(err, 0, "unexpected error %d");
177 zassert_equal(client->state, 1, "Wrong state");
178 zassert_equal(evt_cb_data.called, 1, "NO event");
179 zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_CONNECTED, "Wrong event");
180 }
181
ZTEST(mqtt_sn_client,test_mqtt_sn_handle_advertise)182 static ZTEST(mqtt_sn_client, test_mqtt_sn_handle_advertise)
183 {
184 static uint8_t advertise[] = {5, 0x00, 0x0c, 0, 1};
185 static uint8_t connack[] = {3, 0x05, 0x00};
186 int err;
187
188 err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
189 sizeof(rx));
190 zassert_equal(err, 0, "unexpected error %d");
191
192 err = input(mqtt_client, advertise, sizeof(advertise), &gw_addr);
193 zassert_equal(err, 0, "unexpected error %d");
194 zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
195 zassert_equal(evt_cb_data.called, 1, "NO event");
196 zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_ADVERTISE, "Wrong event");
197
198 err = input(mqtt_client, advertise, sizeof(advertise), &gw_addr);
199 zassert_equal(err, 0, "unexpected error %d");
200 zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
201 zassert_equal(sys_slist_len(&mqtt_client->gateway), 1, "Too many Gateways stored.");
202 zassert_equal(evt_cb_data.called, 2, "Unexpected event");
203 zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_ADVERTISE, "Wrong event");
204
205 err = mqtt_sn_connect(mqtt_client, false, false);
206 zassert_equal(err, 0, "unexpected error %d");
207 assert_msg_send(1, 12, &gw_addr);
208 zassert_equal(mqtt_client->state, 0, "Wrong state");
209 zassert_equal(evt_cb_data.called, 2, "Unexpected event");
210
211 err = input(mqtt_client, connack, sizeof(connack), &gw_addr);
212 zassert_equal(err, 0, "unexpected error %d");
213 zassert_equal(mqtt_client->state, 1, "Wrong state");
214 zassert_equal(evt_cb_data.called, 3, "NO event");
215 zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_CONNECTED, "Wrong event");
216
217 err = k_sem_take(&mqtt_sn_cb_sem, K_NO_WAIT);
218 err = k_sem_take(&mqtt_sn_cb_sem, K_SECONDS(10));
219 zassert_equal(err, 0, "Timed out waiting for callback.");
220
221 zassert_true(sys_slist_is_empty(&mqtt_client->gateway), "GW not cleared on timeout");
222 zassert_equal(evt_cb_data.called, 4, "NO event");
223 zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_DISCONNECTED, "Wrong event");
224 zassert_equal(mqtt_client->state, 0, "Wrong state");
225
226 mqtt_sn_client_deinit(mqtt_client);
227 }
228
ZTEST(mqtt_sn_client,test_mqtt_sn_add_gw)229 static ZTEST(mqtt_sn_client, test_mqtt_sn_add_gw)
230 {
231 int err;
232
233 err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
234 sizeof(rx));
235 zassert_equal(err, 0, "unexpected error %d");
236
237 err = mqtt_sn_add_gw(mqtt_client, gw_id, gw_addr);
238 zassert_equal(err, 0, "unexpected error %d");
239 zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
240 zassert_equal(evt_cb_data.called, 0, "Unexpected event");
241
242 mqtt_sn_client_deinit(mqtt_client);
243 }
244
/* Test sending SEARCHGW and handling a GWINFO response sent by the gateway itself */
ZTEST(mqtt_sn_client,test_mqtt_sn_search_gw)246 static ZTEST(mqtt_sn_client, test_mqtt_sn_search_gw)
247 {
248 int err;
249 static uint8_t gwinfo[3];
250
251 gwinfo[0] = 3;
252 gwinfo[1] = 0x02;
253 gwinfo[2] = gw_id;
254
255 err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
256 sizeof(rx));
257 zassert_equal(err, 0, "unexpected error %d");
258
259 err = k_sem_take(&mqtt_sn_tx_sem, K_NO_WAIT);
260 err = mqtt_sn_search(mqtt_client, 1);
261 zassert_equal(err, 0, "unexpected error %d");
262
263 err = k_sem_take(&mqtt_sn_tx_sem, K_SECONDS(10));
264 zassert_equal(err, 0, "Timed out waiting for callback.");
265
266 assert_msg_send(1, 3, NULL);
267 zassert_equal(mqtt_client->state, 0, "Wrong state");
268 zassert_equal(evt_cb_data.called, 0, "Unexpected event");
269
270 err = input(mqtt_client, gwinfo, sizeof(gwinfo), &gw_addr);
271 zassert_equal(err, 0, "unexpected error %d");
272 zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
273 zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_GWINFO, "Wrong event");
274
275 mqtt_sn_client_deinit(mqtt_client);
276 }
277
/* Test sending SEARCHGW and handling a GWINFO response from another client (includes GwAdd) */
ZTEST(mqtt_sn_client,test_mqtt_sn_search_peer)279 static ZTEST(mqtt_sn_client, test_mqtt_sn_search_peer)
280 {
281 int err;
282 static uint8_t gwinfo[3 + 3];
283
284 gwinfo[0] = 3 + gw_addr.size;
285 gwinfo[1] = 0x02;
286 gwinfo[2] = gw_id;
287 memcpy(&gwinfo[3], gw_addr.data, 3);
288
289 err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
290 sizeof(rx));
291 zassert_equal(err, 0, "unexpected error %d");
292
293 err = k_sem_take(&mqtt_sn_tx_sem, K_NO_WAIT);
294 err = mqtt_sn_search(mqtt_client, 1);
295 zassert_equal(err, 0, "unexpected error %d");
296
297 err = k_sem_take(&mqtt_sn_tx_sem, K_SECONDS(10));
298 zassert_equal(err, 0, "Timed out waiting for callback.");
299
300 assert_msg_send(1, 3, NULL);
301 zassert_equal(mqtt_client->state, 0, "Wrong state");
302 zassert_equal(evt_cb_data.called, 0, "Unexpected event");
303
304 err = input(mqtt_client, gwinfo, sizeof(gwinfo), &gw_addr);
305 zassert_equal(err, 0, "unexpected error %d");
306 zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
307 zassert_equal(evt_cb_data.called, 1, "NO event");
308 zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_GWINFO, "Wrong event");
309
310 mqtt_sn_client_deinit(mqtt_client);
311 }
312
ZTEST(mqtt_sn_client,test_mqtt_sn_respond_searchgw)313 static ZTEST(mqtt_sn_client, test_mqtt_sn_respond_searchgw)
314 {
315 int err;
316 static uint8_t searchgw[] = {3, 0x01, 1};
317
318 err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
319 sizeof(rx));
320 zassert_equal(err, 0, "unexpected error %d");
321
322 err = mqtt_sn_add_gw(mqtt_client, gw_id, gw_addr);
323 zassert_equal(err, 0, "unexpected error %d");
324 zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
325 zassert_equal(evt_cb_data.called, 0, "Unexpected event");
326
327 err = k_sem_take(&mqtt_sn_tx_sem, K_NO_WAIT);
328 err = input(mqtt_client, searchgw, sizeof(searchgw), &client2_id);
329 zassert_equal(err, 0, "unexpected error %d");
330
331 err = k_sem_take(&mqtt_sn_tx_sem, K_SECONDS(10));
332 zassert_equal(err, 0, "Timed out waiting for callback.");
333
334 zassert_equal(evt_cb_data.called, 1, "NO event");
335 zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_SEARCHGW, "Wrong event");
336 assert_msg_send(1, 3 + gw_addr.size, NULL);
337
338 mqtt_sn_client_deinit(mqtt_client);
339 }
340
ZTEST(mqtt_sn_client,test_mqtt_sn_connect_no_will)341 static ZTEST(mqtt_sn_client, test_mqtt_sn_connect_no_will)
342 {
343 mqtt_sn_connect_no_will(mqtt_client);
344 mqtt_sn_client_deinit(mqtt_client);
345 }
346
ZTEST(mqtt_sn_client,test_mqtt_sn_connect_will)347 static ZTEST(mqtt_sn_client, test_mqtt_sn_connect_will)
348 {
349 static uint8_t willtopicreq[] = {2, 0x06};
350 static uint8_t willmsgreq[] = {2, 0x08};
351 static uint8_t connack[] = {3, 0x05, 0x00};
352
353 int err;
354
355 err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
356 sizeof(rx));
357 zassert_equal(err, 0, "unexpected error %d");
358
359 err = mqtt_sn_add_gw(mqtt_client, gw_id, gw_addr);
360 zassert_equal(err, 0, "unexpected error %d");
361 zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
362 zassert_equal(evt_cb_data.called, 0, "Unexpected event");
363
364 mqtt_client->will_topic = MQTT_SN_DATA_STRING_LITERAL("topic");
365 mqtt_client->will_msg = MQTT_SN_DATA_STRING_LITERAL("msg");
366
367 err = mqtt_sn_connect(mqtt_client, true, false);
368 zassert_equal(err, 0, "unexpected error %d");
369 assert_msg_send(1, 12, &gw_addr);
370 zassert_equal(mqtt_client->state, 0, "Wrong state");
371
372 err = input(mqtt_client, willtopicreq, sizeof(willtopicreq), &gw_addr);
373 zassert_equal(err, 0, "unexpected error %d");
374 zassert_equal(mqtt_client->state, 0, "Wrong state");
375 assert_msg_send(1, 8, &gw_addr);
376
377 err = input(mqtt_client, willmsgreq, sizeof(willmsgreq), &gw_addr);
378 zassert_equal(err, 0, "unexpected error %d");
379 zassert_equal(mqtt_client->state, 0, "Wrong state");
380 zassert_equal(evt_cb_data.called, 0, "Unexpected event");
381 assert_msg_send(1, 5, &gw_addr);
382
383 err = input(mqtt_client, connack, sizeof(connack), &gw_addr);
384 zassert_equal(err, 0, "unexpected error %d");
385 zassert_equal(mqtt_client->state, 1, "Wrong state");
386 zassert_equal(evt_cb_data.called, 1, "NO event");
387 zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_CONNECTED, "Wrong event");
388
389 mqtt_sn_client_deinit(mqtt_client);
390 }
391
ZTEST(mqtt_sn_client,test_mqtt_sn_publish_qos0)392 static ZTEST(mqtt_sn_client, test_mqtt_sn_publish_qos0)
393 {
394 struct mqtt_sn_data data = MQTT_SN_DATA_STRING_LITERAL("Hello, World!");
395 struct mqtt_sn_data topic = MQTT_SN_DATA_STRING_LITERAL("zephyr");
396 /* registration ack with topic ID 0x1A1B, msg ID 0x0001, return code accepted */
397 uint8_t regack[] = {7, 0x0B, 0x1A, 0x1B, 0x00, 0x01, 0};
398 int err;
399
400 mqtt_sn_connect_no_will(mqtt_client);
401 err = k_sem_take(&mqtt_sn_tx_sem, K_NO_WAIT);
402 err = mqtt_sn_publish(mqtt_client, MQTT_SN_QOS_0, &topic, false, &data);
403 zassert_equal(err, 0, "Unexpected error %d", err);
404
405 assert_msg_send(0, 0, NULL);
406
407 /* Expect a REGISTER to be sent */
408 err = k_sem_take(&mqtt_sn_tx_sem, K_SECONDS(10));
409 zassert_equal(err, 0, "Timed out waiting for callback.");
410 assert_msg_send(1, 12, &gw_addr);
411 err = input(mqtt_client, regack, sizeof(regack), &gw_addr);
412 zassert_equal(err, 0, "unexpected error %d");
413 err = k_sem_take(&mqtt_sn_tx_sem, K_NO_WAIT);
414 assert_msg_send(0, 0, NULL);
415 err = k_sem_take(&mqtt_sn_tx_sem, K_SECONDS(10));
416 zassert_equal(err, 0, "Timed out waiting for callback.");
417 assert_msg_send(1, 20, &gw_addr);
418
419 zassert_true(sys_slist_is_empty(&mqtt_client->publish), "Publish not empty");
420 zassert_false(sys_slist_is_empty(&mqtt_client->topic), "Topic empty");
421
422 mqtt_sn_client_deinit(mqtt_client);
423 }
424
/* Register the suite; setup() is the only hook and runs before every test. */
ZTEST_SUITE(mqtt_sn_client,
	    NULL,  /* predicate */
	    NULL,  /* suite setup */
	    setup, /* before each test */
	    NULL,  /* after each test */
	    NULL); /* suite teardown */
426