1 /*
2  * Copyright (c) 2022 René Beckmann
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <zephyr/net/mqtt_sn.h>
8 #include <zephyr/sys/util.h> /* for ARRAY_SIZE */
9 #include <zephyr/tc_util.h>
10 #include <zephyr/ztest.h>
11 
12 #include <mqtt_sn_msg.h>
13 
14 #include <zephyr/logging/log.h>
15 LOG_MODULE_REGISTER(test);
16 
17 static const struct mqtt_sn_data client_id = MQTT_SN_DATA_STRING_LITERAL("zephyr");
18 
19 static uint8_t tx[255];
20 static uint8_t rx[255];
21 
/*
 * What the fake msg_send() transport hook has observed since the last reset
 * (setup() and assert_msg_send() both clear it).
 */
struct msg_send_data {
	int called;                    /* number of msg_send() invocations */
	size_t msg_sz;                 /* size passed to the most recent call */
	int ret;                       /* value msg_send() hands back to the client */
	struct mqtt_sn_client *client; /* client passed to the most recent call */
};

static struct msg_send_data msg_send_data;
28 
msg_send(struct mqtt_sn_client * client,void * buf,size_t sz)29 static int msg_send(struct mqtt_sn_client *client, void *buf, size_t sz)
30 {
31 	msg_send_data.called++;
32 	msg_send_data.msg_sz = sz;
33 	msg_send_data.client = client;
34 
35 	return msg_send_data.ret;
36 }
37 
assert_msg_send(int called,size_t msg_sz)38 static void assert_msg_send(int called, size_t msg_sz)
39 {
40 	zassert_equal(msg_send_data.called, called, "msg_send called %d times instead of %d",
41 		      msg_send_data.called, called);
42 	zassert_equal(msg_send_data.msg_sz, msg_sz, "msg_sz is %zu instead of %zu",
43 		      msg_send_data.msg_sz, msg_sz);
44 
45 	memset(&msg_send_data, 0, sizeof(msg_send_data));
46 }
47 
48 static struct {
49 	struct mqtt_sn_evt last_evt;
50 	int called;
51 } evt_cb_data;
52 
evt_cb(struct mqtt_sn_client * client,const struct mqtt_sn_evt * evt)53 static void evt_cb(struct mqtt_sn_client *client, const struct mqtt_sn_evt *evt)
54 {
55 	memcpy(&evt_cb_data.last_evt, evt, sizeof(*evt));
56 	evt_cb_data.called++;
57 }
58 
59 static bool tp_initialized;
60 static struct mqtt_sn_transport transport;
61 
tp_init(struct mqtt_sn_transport * tp)62 static int tp_init(struct mqtt_sn_transport *tp)
63 {
64 	tp_initialized = true;
65 	return 0;
66 }
67 
/*
 * Frame staged by input() for the fake transport. tp_poll() reports sz, and
 * tp_recv() copies data into the buffer it is given.
 */
struct tp_recv_frame {
	void *data; /* staged bytes, or NULL when nothing is staged */
	ssize_t sz; /* number of staged bytes */
};

static struct tp_recv_frame recv_data;
72 
tp_recv(struct mqtt_sn_client * client,void * buffer,size_t length)73 static ssize_t tp_recv(struct mqtt_sn_client *client, void *buffer, size_t length)
74 {
75 	if (recv_data.data && recv_data.sz > 0 && length >= recv_data.sz) {
76 		memcpy(buffer, recv_data.data, recv_data.sz);
77 	}
78 
79 	return recv_data.sz;
80 }
81 
tp_poll(struct mqtt_sn_client * client)82 int tp_poll(struct mqtt_sn_client *client)
83 {
84 	return recv_data.sz;
85 }
86 
87 static ZTEST_BMEM struct mqtt_sn_client mqtt_clients[3];
88 static ZTEST_BMEM struct mqtt_sn_client *mqtt_client;
89 
setup(void * f)90 static void setup(void *f)
91 {
92 	ARG_UNUSED(f);
93 	static ZTEST_BMEM size_t i;
94 
95 	mqtt_client = &mqtt_clients[i++];
96 
97 	transport = (struct mqtt_sn_transport){
98 		.init = tp_init, .msg_send = msg_send, .recv = tp_recv, .poll = tp_poll};
99 	tp_initialized = false;
100 
101 	memset(&evt_cb_data, 0, sizeof(evt_cb_data));
102 	memset(&msg_send_data, 0, sizeof(msg_send_data));
103 	memset(&recv_data, 0, sizeof(recv_data));
104 }
105 
input(struct mqtt_sn_client * client,void * buf,size_t sz)106 static int input(struct mqtt_sn_client *client, void *buf, size_t sz)
107 {
108 	recv_data.data = buf;
109 	recv_data.sz = sz;
110 
111 	return mqtt_sn_input(client);
112 }
113 
mqtt_sn_connect_no_will(struct mqtt_sn_client * client)114 static void mqtt_sn_connect_no_will(struct mqtt_sn_client *client)
115 {
116 	/* connack with return code accepted */
117 	static uint8_t connack[] = {3, 0x05, 0x00};
118 	int err;
119 
120 	err = mqtt_sn_client_init(client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
121 				  sizeof(rx));
122 	zassert_equal(err, 0, "unexpected error %d");
123 	zassert_true(tp_initialized, "Transport not initialized");
124 
125 	err = mqtt_sn_connect(client, false, false);
126 	zassert_equal(err, 0, "unexpected error %d");
127 	assert_msg_send(1, 12);
128 	zassert_equal(client->state, 0, "Wrong state");
129 	zassert_equal(evt_cb_data.called, 0, "Unexpected event");
130 
131 	err = input(client, connack, sizeof(connack));
132 	zassert_equal(err, 0, "unexpected error %d");
133 	zassert_equal(client->state, 1, "Wrong state");
134 	zassert_equal(evt_cb_data.called, 1, "NO event");
135 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_CONNECTED, "Wrong event");
136 	k_sleep(K_MSEC(10));
137 }
138 
ZTEST(mqtt_sn_client,test_mqtt_sn_connect_no_will)139 static ZTEST(mqtt_sn_client, test_mqtt_sn_connect_no_will)
140 {
141 
142 	mqtt_sn_connect_no_will(mqtt_client);
143 }
144 
ZTEST(mqtt_sn_client,test_mqtt_sn_connect_will)145 static ZTEST(mqtt_sn_client, test_mqtt_sn_connect_will)
146 {
147 	static uint8_t willtopicreq[] = {2, 0x06};
148 	static uint8_t willmsgreq[] = {2, 0x08};
149 	static uint8_t connack[] = {3, 0x05, 0x00};
150 
151 	int err;
152 
153 	err = mqtt_sn_client_init(mqtt_client, &client_id, &transport, evt_cb, tx, sizeof(tx), rx,
154 				  sizeof(rx));
155 	zassert_equal(err, 0, "unexpected error %d");
156 
157 	mqtt_client->will_topic = MQTT_SN_DATA_STRING_LITERAL("topic");
158 	mqtt_client->will_msg = MQTT_SN_DATA_STRING_LITERAL("msg");
159 
160 	err = mqtt_sn_connect(mqtt_client, true, false);
161 	zassert_equal(err, 0, "unexpected error %d");
162 	assert_msg_send(1, 12);
163 	zassert_equal(mqtt_client->state, 0, "Wrong state");
164 
165 	err = input(mqtt_client, willtopicreq, sizeof(willtopicreq));
166 	zassert_equal(err, 0, "unexpected error %d");
167 	zassert_equal(mqtt_client->state, 0, "Wrong state");
168 	assert_msg_send(1, 8);
169 
170 	err = input(mqtt_client, willmsgreq, sizeof(willmsgreq));
171 	zassert_equal(err, 0, "unexpected error %d");
172 	zassert_equal(mqtt_client->state, 0, "Wrong state");
173 	zassert_equal(evt_cb_data.called, 0, "Unexpected event");
174 	assert_msg_send(1, 5);
175 
176 	err = input(mqtt_client, connack, sizeof(connack));
177 	zassert_equal(err, 0, "unexpected error %d");
178 	zassert_equal(mqtt_client->state, 1, "Wrong state");
179 	zassert_equal(evt_cb_data.called, 1, "NO event");
180 	zassert_equal(evt_cb_data.last_evt.type, MQTT_SN_EVT_CONNECTED, "Wrong event");
181 	k_sleep(K_MSEC(10));
182 }
183 
ZTEST(mqtt_sn_client,test_mqtt_sn_publish_qos0)184 static ZTEST(mqtt_sn_client, test_mqtt_sn_publish_qos0)
185 {
186 	struct mqtt_sn_data data = MQTT_SN_DATA_STRING_LITERAL("Hello, World!");
187 	struct mqtt_sn_data topic = MQTT_SN_DATA_STRING_LITERAL("zephyr");
188 	/* registration ack with topic ID 0x1A1B, msg ID 0x0001, return code accepted */
189 	uint8_t regack[] = {7, 0x0B, 0x1A, 0x1B, 0x00, 0x01, 0};
190 	int err;
191 
192 	mqtt_sn_connect_no_will(mqtt_client);
193 	err = mqtt_sn_publish(mqtt_client, MQTT_SN_QOS_0, &topic, false, &data);
194 	zassert_equal(err, 0, "Unexpected error %d", err);
195 
196 	assert_msg_send(0, 0);
197 	k_sleep(K_MSEC(10));
198 	/* Expect a REGISTER to be sent */
199 	assert_msg_send(1, 12);
200 	err = input(mqtt_client, regack, sizeof(regack));
201 	zassert_equal(err, 0, "unexpected error %d");
202 	assert_msg_send(0, 0);
203 	k_sleep(K_MSEC(10));
204 	assert_msg_send(1, 20);
205 
206 	zassert_true(sys_slist_is_empty(&mqtt_client->publish), "Publish not empty");
207 	zassert_false(sys_slist_is_empty(&mqtt_client->topic), "Topic empty");
208 }
209 
/*
 * Register the suite. Only the per-test "before" hook is used: setup() runs
 * ahead of every test to hand out a fresh client and reset the fakes.
 */
ZTEST_SUITE(mqtt_sn_client,
	    NULL,  /* predicate */
	    NULL,  /* suite setup */
	    setup, /* before each test */
	    NULL,  /* after each test */
	    NULL); /* suite teardown */
211