1 /*
2  * Copyright (c) 2022 René Beckmann
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
#include <string.h>

#include <zephyr/net/mqtt_sn.h>
#include <zephyr/sys/util.h> /* for ARRAY_SIZE */
#include <zephyr/tc_util.h>
#include <zephyr/ztest.h>

#include <mqtt_sn_msg.h>

#include <zephyr/logging/log.h>
15 LOG_MODULE_REGISTER(test);
16 
17 static const struct mqtt_sn_data client_id = MQTT_SN_DATA_STRING_LITERAL("zephyr");
18 static const struct mqtt_sn_data client2_id = MQTT_SN_DATA_STRING_LITERAL("zephyr2");
19 static const uint8_t gw_id = 12;
20 static const struct mqtt_sn_data gw_addr = MQTT_SN_DATA_STRING_LITERAL("gw1");
21 
22 static uint8_t tx[255];
23 static uint8_t rx[255];
24 
25 static struct msg_send_data {
26 	int called;
27 	size_t msg_sz;
28 	int ret;
29 	const void *dest_addr;
30 	size_t addrlen;
31 	struct mqtt_sn_client *client;
32 } msg_send_data;
33 
34 struct k_sem mqtt_sn_tx_sem;
35 struct k_sem mqtt_sn_rx_sem;
36 struct k_sem mqtt_sn_cb_sem;
37 
mqtt_sn_data_cmp(struct mqtt_sn_data data1,struct mqtt_sn_data data2)38 int mqtt_sn_data_cmp(struct mqtt_sn_data data1, struct mqtt_sn_data data2)
39 {
40 	return data1.size == data2.size && strncmp(data1.data, data2.data, data1.size);
41 }
42 
msg_sendto(struct mqtt_sn_client * client,void * buf,size_t sz,const void * dest_addr,size_t addrlen)43 static int msg_sendto(struct mqtt_sn_client *client, void *buf, size_t sz, const void *dest_addr,
44 		      size_t addrlen)
45 {
46 	msg_send_data.called++;
47 	msg_send_data.msg_sz = sz;
48 	msg_send_data.client = client;
49 	msg_send_data.dest_addr = dest_addr;
50 	msg_send_data.addrlen = addrlen;
51 
52 	k_sem_give(&mqtt_sn_tx_sem);
53 
54 	return msg_send_data.ret;
55 }
56 
assert_msg_send(int called,size_t msg_sz,const struct mqtt_sn_data * dest_addr)57 static void assert_msg_send(int called, size_t msg_sz, const struct mqtt_sn_data *dest_addr)
58 {
59 	zassert_equal(msg_send_data.called, called, "msg_send called %d times instead of %d",
60 		      msg_send_data.called, called);
61 	zassert_equal(msg_send_data.msg_sz, msg_sz, "msg_sz is %zu instead of %zu",
62 		      msg_send_data.msg_sz, msg_sz);
63 	if (dest_addr != NULL) {
64 		zassert_equal(mqtt_sn_data_cmp(*dest_addr,
65 					       *((struct mqtt_sn_data *)msg_send_data.dest_addr)),
66 			      0, "Addresses incorrect");
67 	}
68 
69 	memset(&msg_send_data, 0, sizeof(msg_send_data));
70 }
71 
72 static struct {
73 	struct mqtt_sn_evt last_evt;
74 	int called;
75 } evt_cb_data;
76 
evt_cb(struct mqtt_sn_client * client,const struct mqtt_sn_evt * evt)77 static void evt_cb(struct mqtt_sn_client *client, const struct mqtt_sn_evt *evt)
78 {
79 	memcpy(&evt_cb_data.last_evt, evt, sizeof(*evt));
80 	evt_cb_data.called++;
81 
82 	k_sem_give(&mqtt_sn_cb_sem);
83 }
84 
85 static bool tp_initialized;
86 static struct mqtt_sn_transport transport;
87 
tp_init(struct mqtt_sn_transport * tp)88 static int tp_init(struct mqtt_sn_transport *tp)
89 {
90 	tp_initialized = true;
91 	return 0;
92 }
93 
/* The datagram staged by input() for the mock recvfrom()/poll() hooks. */
struct recvfrom_record {
	void *data;           /* payload bytes */
	ssize_t sz;           /* payload size */
	const void *src_addr; /* sender address bytes */
	size_t addrlen;       /* sender address length */
};

static struct recvfrom_record recvfrom_data;
100 
tp_recvfrom(struct mqtt_sn_client * client,void * buffer,size_t length,void * src_addr,size_t * addrlen)101 static ssize_t tp_recvfrom(struct mqtt_sn_client *client, void *buffer, size_t length,
102 			   void *src_addr, size_t *addrlen)
103 {
104 	if (recvfrom_data.data && recvfrom_data.sz > 0 && length >= recvfrom_data.sz) {
105 		memcpy(buffer, recvfrom_data.data, recvfrom_data.sz);
106 		memcpy(src_addr, recvfrom_data.src_addr, recvfrom_data.addrlen);
107 		*addrlen = recvfrom_data.addrlen;
108 
109 		k_sem_give(&mqtt_sn_rx_sem);
110 	}
111 
112 	return recvfrom_data.sz;
113 }
114 
tp_poll(struct mqtt_sn_client * client)115 int tp_poll(struct mqtt_sn_client *client)
116 {
117 	return recvfrom_data.sz;
118 }
119 
120 static ZTEST_BMEM struct mqtt_sn_client mqtt_clients[8];
121 static ZTEST_BMEM struct mqtt_sn_client *mqtt_client;
122 
setup(void * f)123 static void setup(void *f)
124 {
125 	ARG_UNUSED(f);
126 	static ZTEST_BMEM size_t i;
127 
128 	mqtt_client = &mqtt_clients[i++];
129 
130 	transport = (struct mqtt_sn_transport){
131 		.init = tp_init, .sendto = msg_sendto, .recvfrom = tp_recvfrom, .poll = tp_poll};
132 	tp_initialized = false;
133 
134 	memset(&evt_cb_data, 0, sizeof(evt_cb_data));
135 	memset(&msg_send_data, 0, sizeof(msg_send_data));
136 	memset(&recvfrom_data, 0, sizeof(recvfrom_data));
137 	k_sem_init(&mqtt_sn_tx_sem, 0, 1);
138 	k_sem_init(&mqtt_sn_rx_sem, 0, 1);
139 	k_sem_init(&mqtt_sn_cb_sem, 0, 1);
140 }
141 
input(struct mqtt_sn_client * client,void * buf,size_t sz,const struct mqtt_sn_data * src_addr)142 static int input(struct mqtt_sn_client *client, void *buf, size_t sz,
143 		 const struct mqtt_sn_data *src_addr)
144 {
145 	recvfrom_data.data = buf;
146 	recvfrom_data.sz = sz;
147 	recvfrom_data.src_addr = src_addr->data;
148 	recvfrom_data.addrlen = src_addr->size;
149 
150 	return mqtt_sn_input(client);
151 }
152 
mqtt_sn_connect_no_will(struct mqtt_sn_client * client)153 static void mqtt_sn_connect_no_will(struct mqtt_sn_client *client)
154 {
155 	/* connack with return code accepted */
156 	static uint8_t connack[] = {3, 0x05, 0x00};
157 	int err;
158 
159 	err = mqtt_sn_client_init(client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
160 				  sizeof(rx));
161 	zassert_equal(err, 0, "unexpected error %d");
162 	zassert_true(tp_initialized, "Transport not initialized");
163 
164 	err = mqtt_sn_add_gw(client, gw_id, gw_addr);
165 	zassert_equal(err, 0, "unexpected error %d");
166 	zassert_equal(evt_cb_data.called, 0, "Unexpected event");
167 	zassert_false(sys_slist_is_empty(&client->gateway), "GW not saved.");
168 
169 	err = mqtt_sn_connect(client, false, false);
170 	zassert_equal(err, 0, "unexpected error %d");
171 	assert_msg_send(1, 12, &gw_addr);
172 	zassert_equal(client->state, 0, "Wrong state");
173 	zassert_equal(evt_cb_data.called, 0, "Unexpected event");
174 
175 	err = input(client, connack, sizeof(connack), &gw_addr);
176 	zassert_equal(err, 0, "unexpected error %d");
177 	zassert_equal(client->state, 1, "Wrong state");
178 	zassert_equal(evt_cb_data.called, 1, "NO event");
179 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_CONNECTED, "Wrong event");
180 }
181 
ZTEST(mqtt_sn_client,test_mqtt_sn_handle_advertise)182 static ZTEST(mqtt_sn_client, test_mqtt_sn_handle_advertise)
183 {
184 	static uint8_t advertise[] = {5, 0x00, 0x0c, 0, 1};
185 	static uint8_t connack[] = {3, 0x05, 0x00};
186 	int err;
187 
188 	err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
189 				  sizeof(rx));
190 	zassert_equal(err, 0, "unexpected error %d");
191 
192 	err = input(mqtt_client, advertise, sizeof(advertise), &gw_addr);
193 	zassert_equal(err, 0, "unexpected error %d");
194 	zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
195 	zassert_equal(evt_cb_data.called, 1, "NO event");
196 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_ADVERTISE, "Wrong event");
197 
198 	err = input(mqtt_client, advertise, sizeof(advertise), &gw_addr);
199 	zassert_equal(err, 0, "unexpected error %d");
200 	zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
201 	zassert_equal(sys_slist_len(&mqtt_client->gateway), 1, "Too many Gateways stored.");
202 	zassert_equal(evt_cb_data.called, 2, "Unexpected event");
203 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_ADVERTISE, "Wrong event");
204 
205 	err = mqtt_sn_connect(mqtt_client, false, false);
206 	zassert_equal(err, 0, "unexpected error %d");
207 	assert_msg_send(1, 12, &gw_addr);
208 	zassert_equal(mqtt_client->state, 0, "Wrong state");
209 	zassert_equal(evt_cb_data.called, 2, "Unexpected event");
210 
211 	err = input(mqtt_client, connack, sizeof(connack), &gw_addr);
212 	zassert_equal(err, 0, "unexpected error %d");
213 	zassert_equal(mqtt_client->state, 1, "Wrong state");
214 	zassert_equal(evt_cb_data.called, 3, "NO event");
215 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_CONNECTED, "Wrong event");
216 
217 	err = k_sem_take(&mqtt_sn_cb_sem, K_NO_WAIT);
218 	err = k_sem_take(&mqtt_sn_cb_sem, K_SECONDS(10));
219 	zassert_equal(err, 0, "Timed out waiting for callback.");
220 
221 	zassert_true(sys_slist_is_empty(&mqtt_client->gateway), "GW not cleared on timeout");
222 	zassert_equal(evt_cb_data.called, 4, "NO event");
223 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_DISCONNECTED, "Wrong event");
224 	zassert_equal(mqtt_client->state, 0, "Wrong state");
225 
226 	mqtt_sn_client_deinit(mqtt_client);
227 }
228 
ZTEST(mqtt_sn_client,test_mqtt_sn_add_gw)229 static ZTEST(mqtt_sn_client, test_mqtt_sn_add_gw)
230 {
231 	int err;
232 
233 	err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
234 				  sizeof(rx));
235 	zassert_equal(err, 0, "unexpected error %d");
236 
237 	err = mqtt_sn_add_gw(mqtt_client, gw_id, gw_addr);
238 	zassert_equal(err, 0, "unexpected error %d");
239 	zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
240 	zassert_equal(evt_cb_data.called, 0, "Unexpected event");
241 
242 	mqtt_sn_client_deinit(mqtt_client);
243 }
244 
245 /* Test send SEARCHGW and GW response */
ZTEST(mqtt_sn_client,test_mqtt_sn_search_gw)246 static ZTEST(mqtt_sn_client, test_mqtt_sn_search_gw)
247 {
248 	int err;
249 	static uint8_t gwinfo[3];
250 
251 	gwinfo[0] = 3;
252 	gwinfo[1] = 0x02;
253 	gwinfo[2] = gw_id;
254 
255 	err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
256 				  sizeof(rx));
257 	zassert_equal(err, 0, "unexpected error %d");
258 
259 	err = k_sem_take(&mqtt_sn_tx_sem, K_NO_WAIT);
260 	err = mqtt_sn_search(mqtt_client, 1);
261 	zassert_equal(err, 0, "unexpected error %d");
262 
263 	err = k_sem_take(&mqtt_sn_tx_sem, K_SECONDS(10));
264 	zassert_equal(err, 0, "Timed out waiting for callback.");
265 
266 	assert_msg_send(1, 3, NULL);
267 	zassert_equal(mqtt_client->state, 0, "Wrong state");
268 	zassert_equal(evt_cb_data.called, 0, "Unexpected event");
269 
270 	err = input(mqtt_client, gwinfo, sizeof(gwinfo), &gw_addr);
271 	zassert_equal(err, 0, "unexpected error %d");
272 	zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
273 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_GWINFO, "Wrong event");
274 
275 	mqtt_sn_client_deinit(mqtt_client);
276 }
277 
278 /* Test send SEARCHGW and peer response */
ZTEST(mqtt_sn_client,test_mqtt_sn_search_peer)279 static ZTEST(mqtt_sn_client, test_mqtt_sn_search_peer)
280 {
281 	int err;
282 	static uint8_t gwinfo[3 + 3];
283 
284 	gwinfo[0] = 3 + gw_addr.size;
285 	gwinfo[1] = 0x02;
286 	gwinfo[2] = gw_id;
287 	memcpy(&gwinfo[3], gw_addr.data, 3);
288 
289 	err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
290 				  sizeof(rx));
291 	zassert_equal(err, 0, "unexpected error %d");
292 
293 	err = k_sem_take(&mqtt_sn_tx_sem, K_NO_WAIT);
294 	err = mqtt_sn_search(mqtt_client, 1);
295 	zassert_equal(err, 0, "unexpected error %d");
296 
297 	err = k_sem_take(&mqtt_sn_tx_sem, K_SECONDS(10));
298 	zassert_equal(err, 0, "Timed out waiting for callback.");
299 
300 	assert_msg_send(1, 3, NULL);
301 	zassert_equal(mqtt_client->state, 0, "Wrong state");
302 	zassert_equal(evt_cb_data.called, 0, "Unexpected event");
303 
304 	err = input(mqtt_client, gwinfo, sizeof(gwinfo), &gw_addr);
305 	zassert_equal(err, 0, "unexpected error %d");
306 	zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
307 	zassert_equal(evt_cb_data.called, 1, "NO event");
308 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_GWINFO, "Wrong event");
309 
310 	mqtt_sn_client_deinit(mqtt_client);
311 }
312 
ZTEST(mqtt_sn_client,test_mqtt_sn_respond_searchgw)313 static ZTEST(mqtt_sn_client, test_mqtt_sn_respond_searchgw)
314 {
315 	int err;
316 	static uint8_t searchgw[] = {3, 0x01, 1};
317 
318 	err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
319 				  sizeof(rx));
320 	zassert_equal(err, 0, "unexpected error %d");
321 
322 	err = mqtt_sn_add_gw(mqtt_client, gw_id, gw_addr);
323 	zassert_equal(err, 0, "unexpected error %d");
324 	zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
325 	zassert_equal(evt_cb_data.called, 0, "Unexpected event");
326 
327 	err = k_sem_take(&mqtt_sn_tx_sem, K_NO_WAIT);
328 	err = input(mqtt_client, searchgw, sizeof(searchgw), &client2_id);
329 	zassert_equal(err, 0, "unexpected error %d");
330 
331 	err = k_sem_take(&mqtt_sn_tx_sem, K_SECONDS(10));
332 	zassert_equal(err, 0, "Timed out waiting for callback.");
333 
334 	zassert_equal(evt_cb_data.called, 1, "NO event");
335 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_SEARCHGW, "Wrong event");
336 	assert_msg_send(1, 3 + gw_addr.size, NULL);
337 
338 	mqtt_sn_client_deinit(mqtt_client);
339 }
340 
ZTEST(mqtt_sn_client,test_mqtt_sn_connect_no_will)341 static ZTEST(mqtt_sn_client, test_mqtt_sn_connect_no_will)
342 {
343 	mqtt_sn_connect_no_will(mqtt_client);
344 	mqtt_sn_client_deinit(mqtt_client);
345 }
346 
ZTEST(mqtt_sn_client,test_mqtt_sn_connect_will)347 static ZTEST(mqtt_sn_client, test_mqtt_sn_connect_will)
348 {
349 	static uint8_t willtopicreq[] = {2, 0x06};
350 	static uint8_t willmsgreq[] = {2, 0x08};
351 	static uint8_t connack[] = {3, 0x05, 0x00};
352 
353 	int err;
354 
355 	err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
356 				  sizeof(rx));
357 	zassert_equal(err, 0, "unexpected error %d");
358 
359 	err = mqtt_sn_add_gw(mqtt_client, gw_id, gw_addr);
360 	zassert_equal(err, 0, "unexpected error %d");
361 	zassert_false(sys_slist_is_empty(&mqtt_client->gateway), "GW not saved.");
362 	zassert_equal(evt_cb_data.called, 0, "Unexpected event");
363 
364 	mqtt_client->will_topic = MQTT_SN_DATA_STRING_LITERAL("topic");
365 	mqtt_client->will_msg = MQTT_SN_DATA_STRING_LITERAL("msg");
366 
367 	err = mqtt_sn_connect(mqtt_client, true, false);
368 	zassert_equal(err, 0, "unexpected error %d");
369 	assert_msg_send(1, 12, &gw_addr);
370 	zassert_equal(mqtt_client->state, 0, "Wrong state");
371 
372 	err = input(mqtt_client, willtopicreq, sizeof(willtopicreq), &gw_addr);
373 	zassert_equal(err, 0, "unexpected error %d");
374 	zassert_equal(mqtt_client->state, 0, "Wrong state");
375 	assert_msg_send(1, 8, &gw_addr);
376 
377 	err = input(mqtt_client, willmsgreq, sizeof(willmsgreq), &gw_addr);
378 	zassert_equal(err, 0, "unexpected error %d");
379 	zassert_equal(mqtt_client->state, 0, "Wrong state");
380 	zassert_equal(evt_cb_data.called, 0, "Unexpected event");
381 	assert_msg_send(1, 5, &gw_addr);
382 
383 	err = input(mqtt_client, connack, sizeof(connack), &gw_addr);
384 	zassert_equal(err, 0, "unexpected error %d");
385 	zassert_equal(mqtt_client->state, 1, "Wrong state");
386 	zassert_equal(evt_cb_data.called, 1, "NO event");
387 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_CONNECTED, "Wrong event");
388 
389 	mqtt_sn_client_deinit(mqtt_client);
390 }
391 
ZTEST(mqtt_sn_client,test_mqtt_sn_publish_qos0)392 static ZTEST(mqtt_sn_client, test_mqtt_sn_publish_qos0)
393 {
394 	struct mqtt_sn_data data = MQTT_SN_DATA_STRING_LITERAL("Hello, World!");
395 	struct mqtt_sn_data topic = MQTT_SN_DATA_STRING_LITERAL("zephyr");
396 	/* registration ack with topic ID 0x1A1B, msg ID 0x0001, return code accepted */
397 	uint8_t regack[] = {7, 0x0B, 0x1A, 0x1B, 0x00, 0x01, 0};
398 	int err;
399 
400 	mqtt_sn_connect_no_will(mqtt_client);
401 	err = k_sem_take(&mqtt_sn_tx_sem, K_NO_WAIT);
402 	err = mqtt_sn_publish(mqtt_client, MQTT_SN_QOS_0, &topic, false, &data);
403 	zassert_equal(err, 0, "Unexpected error %d", err);
404 
405 	assert_msg_send(0, 0, NULL);
406 
407 	/* Expect a REGISTER to be sent */
408 	err = k_sem_take(&mqtt_sn_tx_sem, K_SECONDS(10));
409 	zassert_equal(err, 0, "Timed out waiting for callback.");
410 	assert_msg_send(1, 12, &gw_addr);
411 	err = input(mqtt_client, regack, sizeof(regack), &gw_addr);
412 	zassert_equal(err, 0, "unexpected error %d");
413 	err = k_sem_take(&mqtt_sn_tx_sem, K_NO_WAIT);
414 	assert_msg_send(0, 0, NULL);
415 	err = k_sem_take(&mqtt_sn_tx_sem, K_SECONDS(10));
416 	zassert_equal(err, 0, "Timed out waiting for callback.");
417 	assert_msg_send(1, 20, &gw_addr);
418 
419 	zassert_true(sys_slist_is_empty(&mqtt_client->publish), "Publish not empty");
420 	zassert_false(sys_slist_is_empty(&mqtt_client->topic), "Topic empty");
421 
422 	mqtt_sn_client_deinit(mqtt_client);
423 }
424 
/*
 * Register the suite: no predicate, no suite-level setup/teardown and no
 * per-test "after" hook; setup() runs before every test to hand out a fresh
 * client instance and reset the mock transport and recorded state.
 */
ZTEST_SUITE(mqtt_sn_client, NULL, NULL, setup, NULL, NULL);
426