1 /*
2  * Copyright (c) 2018 Nordic Semiconductor ASA
3  * Copyright (c) 2019 Intel Corporation
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 /** @file mqtt_transport_websocket.c
9  *
10  * @brief Internal functions to handle transport over Websocket.
11  */
12 
13 #include <zephyr/logging/log.h>
14 LOG_MODULE_REGISTER(net_mqtt_websocket, CONFIG_MQTT_LOG_LEVEL);
15 
16 #include <errno.h>
17 #include <zephyr/net/socket.h>
18 #include <zephyr/net/mqtt.h>
19 #include <zephyr/net/websocket.h>
20 
21 #include "mqtt_os.h"
22 #include "mqtt_transport.h"
23 
mqtt_client_websocket_connect(struct mqtt_client * client)24 int mqtt_client_websocket_connect(struct mqtt_client *client)
25 {
26 	const char *extra_headers[] = {
27 		"Sec-WebSocket-Protocol: mqtt\r\n",
28 		NULL
29 	};
30 	int transport_sock;
31 	int ret;
32 
33 	if (client->transport.type == MQTT_TRANSPORT_NON_SECURE_WEBSOCKET) {
34 		ret = mqtt_client_tcp_connect(client);
35 		if (ret < 0) {
36 			return ret;
37 		}
38 
39 		transport_sock = client->transport.tcp.sock;
40 	}
41 #if defined(CONFIG_MQTT_LIB_TLS)
42 	else if (client->transport.type == MQTT_TRANSPORT_SECURE_WEBSOCKET) {
43 		ret = mqtt_client_tls_connect(client);
44 		if (ret < 0) {
45 			return ret;
46 		}
47 
48 		transport_sock = client->transport.tls.sock;
49 	}
50 #endif
51 	else {
52 		return -EINVAL;
53 	}
54 
55 	if (client->transport.websocket.config.url == NULL) {
56 		client->transport.websocket.config.url = "/mqtt";
57 	}
58 
59 	if (client->transport.websocket.config.host == NULL) {
60 		client->transport.websocket.config.host = "localhost";
61 	}
62 
63 	/* If application needs to set some extra header options, then
64 	 * it can set the optional_headers_cb. In this case the app
65 	 * will need to also send "Sec-WebSocket-Protocol: mqtt\r\n"
66 	 * field as the optional_headers field is ignored if the callback
67 	 * is set.
68 	 */
69 	client->transport.websocket.config.optional_headers = extra_headers;
70 
71 	client->transport.websocket.sock = websocket_connect(
72 			transport_sock,
73 			&client->transport.websocket.config,
74 			client->transport.websocket.timeout,
75 			NULL);
76 	if (client->transport.websocket.sock < 0) {
77 		NET_ERR("Websocket connect failed (%d)",
78 			 client->transport.websocket.sock);
79 
80 		(void)close(transport_sock);
81 		return client->transport.websocket.sock;
82 	}
83 
84 	NET_DBG("Connect completed");
85 
86 	return 0;
87 }
88 
mqtt_client_websocket_write(struct mqtt_client * client,const uint8_t * data,uint32_t datalen)89 int mqtt_client_websocket_write(struct mqtt_client *client, const uint8_t *data,
90 				uint32_t datalen)
91 {
92 	uint32_t offset = 0U;
93 	int ret;
94 
95 	while (offset < datalen) {
96 		ret = websocket_send_msg(client->transport.websocket.sock,
97 					 data + offset, datalen - offset,
98 					 WEBSOCKET_OPCODE_DATA_BINARY,
99 					 true, true, SYS_FOREVER_MS);
100 		if (ret < 0) {
101 			return -errno;
102 		}
103 
104 		offset += ret;
105 	}
106 
107 	return 0;
108 }
109 
mqtt_client_websocket_write_msg(struct mqtt_client * client,const struct msghdr * message)110 int mqtt_client_websocket_write_msg(struct mqtt_client *client,
111 				    const struct msghdr *message)
112 {
113 	enum websocket_opcode opcode = WEBSOCKET_OPCODE_DATA_BINARY;
114 	bool final = false;
115 	ssize_t len;
116 	ssize_t ret;
117 	int i;
118 
119 	len = 0;
120 	for (i = 0; i < message->msg_iovlen; i++) {
121 		if (i == message->msg_iovlen - 1) {
122 			final = true;
123 		}
124 
125 		ret = websocket_send_msg(client->transport.websocket.sock,
126 					 message->msg_iov[i].iov_base,
127 					 message->msg_iov[i].iov_len, opcode,
128 					 true, final, SYS_FOREVER_MS);
129 		if (ret < 0) {
130 			return ret;
131 		}
132 
133 		opcode = WEBSOCKET_OPCODE_CONTINUE;
134 		len += ret;
135 	}
136 
137 	return len;
138 }
139 
mqtt_client_websocket_read(struct mqtt_client * client,uint8_t * data,uint32_t buflen,bool shall_block)140 int mqtt_client_websocket_read(struct mqtt_client *client, uint8_t *data,
141 			       uint32_t buflen, bool shall_block)
142 {
143 	int32_t timeout = SYS_FOREVER_MS;
144 	uint32_t message_type = 0U;
145 	int ret;
146 
147 	if (!shall_block) {
148 		timeout = 0;
149 	}
150 
151 	ret = websocket_recv_msg(client->transport.websocket.sock,
152 				 data, buflen, &message_type, NULL, timeout);
153 	if (ret >= 0 && message_type > 0) {
154 		if (message_type & WEBSOCKET_FLAG_CLOSE) {
155 			return 0;
156 		}
157 
158 		if ((ret == 0) || !(message_type & WEBSOCKET_FLAG_BINARY)) {
159 			return -EAGAIN;
160 		}
161 	}
162 	if (ret == -ENOTCONN) {
163 		ret = 0;
164 	}
165 
166 	return ret;
167 }
168 
mqtt_client_websocket_disconnect(struct mqtt_client * client)169 int mqtt_client_websocket_disconnect(struct mqtt_client *client)
170 {
171 	int ret;
172 
173 	NET_INFO("Closing socket %d", client->transport.websocket.sock);
174 
175 	ret = websocket_disconnect(client->transport.websocket.sock);
176 	if (ret < 0) {
177 		NET_ERR("Websocket disconnect failed (%d)", ret);
178 		return ret;
179 	}
180 
181 	if (client->transport.type == MQTT_TRANSPORT_NON_SECURE_WEBSOCKET) {
182 		ret = mqtt_client_tcp_disconnect(client);
183 	}
184 #if defined(CONFIG_MQTT_LIB_TLS)
185 	else if (client->transport.type == MQTT_TRANSPORT_SECURE_WEBSOCKET) {
186 		ret = mqtt_client_tls_disconnect(client);
187 	}
188 #endif
189 
190 	return ret;
191 }
192