1 /*
2  * Copyright (c) 2018 Nordic Semiconductor ASA
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
/** @file mqtt_transport_socket_tls.c
8  *
9  * @brief Internal functions to handle transport over TLS socket.
10  */
11 
12 #include <zephyr/logging/log.h>
13 LOG_MODULE_REGISTER(net_mqtt_sock_tls, CONFIG_MQTT_LOG_LEVEL);
14 
15 #include <errno.h>
16 #include <zephyr/net/socket.h>
17 #include <zephyr/net/mqtt.h>
18 
19 #include "mqtt_os.h"
20 
mqtt_client_tls_connect(struct mqtt_client * client)21 int mqtt_client_tls_connect(struct mqtt_client *client)
22 {
23 	const struct sockaddr *broker = client->broker;
24 	struct mqtt_sec_config *tls_config = &client->transport.tls.config;
25 	int ret;
26 
27 	client->transport.tls.sock = zsock_socket(broker->sa_family,
28 						  SOCK_STREAM, IPPROTO_TLS_1_2);
29 	if (client->transport.tls.sock < 0) {
30 		return -errno;
31 	}
32 
33 	NET_DBG("Created socket %d", client->transport.tls.sock);
34 
35 #if defined(CONFIG_SOCKS)
36 	if (client->transport.proxy.addrlen != 0) {
37 		ret = setsockopt(client->transport.tls.sock,
38 				 SOL_SOCKET, SO_SOCKS5,
39 				 &client->transport.proxy.addr,
40 				 client->transport.proxy.addrlen);
41 		if (ret < 0) {
42 			goto error;
43 		}
44 	}
45 #endif
46 	/* Set secure socket options. */
47 	ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS, TLS_PEER_VERIFY,
48 			       &tls_config->peer_verify,
49 			       sizeof(tls_config->peer_verify));
50 	if (ret < 0) {
51 		goto error;
52 	}
53 
54 	if (tls_config->cipher_list != NULL && tls_config->cipher_count > 0) {
55 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
56 				       TLS_CIPHERSUITE_LIST, tls_config->cipher_list,
57 				       sizeof(int) * tls_config->cipher_count);
58 		if (ret < 0) {
59 			goto error;
60 		}
61 	}
62 
63 	if (tls_config->sec_tag_list != NULL && tls_config->sec_tag_count > 0) {
64 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
65 				       TLS_SEC_TAG_LIST, tls_config->sec_tag_list,
66 				       sizeof(sec_tag_t) * tls_config->sec_tag_count);
67 		if (ret < 0) {
68 			goto error;
69 		}
70 	}
71 
72 #if defined(CONFIG_MQTT_LIB_TLS_USE_ALPN)
73 	if (tls_config->alpn_protocol_name_list != NULL &&
74 		tls_config->alpn_protocol_name_count > 0) {
75 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
76 				TLS_ALPN_LIST, tls_config->alpn_protocol_name_list,
77 				sizeof(const char *) * tls_config->alpn_protocol_name_count);
78 		if (ret < 0) {
79 			goto error;
80 		}
81 	}
82 
83 #endif
84 
85 	if (tls_config->hostname) {
86 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
87 				       TLS_HOSTNAME, tls_config->hostname,
88 				       strlen(tls_config->hostname) + 1);
89 		if (ret < 0) {
90 			goto error;
91 		}
92 	}
93 
94 	if (tls_config->cert_nocopy != TLS_CERT_NOCOPY_NONE) {
95 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
96 				       TLS_CERT_NOCOPY, &tls_config->cert_nocopy,
97 				       sizeof(tls_config->cert_nocopy));
98 		if (ret < 0) {
99 			goto error;
100 		}
101 	}
102 
103 	size_t peer_addr_size = sizeof(struct sockaddr_in6);
104 
105 	if (broker->sa_family == AF_INET) {
106 		peer_addr_size = sizeof(struct sockaddr_in);
107 	}
108 
109 	ret = zsock_connect(client->transport.tls.sock, client->broker,
110 			    peer_addr_size);
111 	if (ret < 0) {
112 		goto error;
113 	}
114 
115 	NET_DBG("Connect completed");
116 	return 0;
117 
118 error:
119 	(void) zsock_close(client->transport.tls.sock);
120 	return -errno;
121 }
122 
mqtt_client_tls_write(struct mqtt_client * client,const uint8_t * data,uint32_t datalen)123 int mqtt_client_tls_write(struct mqtt_client *client, const uint8_t *data,
124 			  uint32_t datalen)
125 {
126 	uint32_t offset = 0U;
127 	int ret;
128 
129 	while (offset < datalen) {
130 		ret = zsock_send(client->transport.tls.sock, data + offset,
131 				 datalen - offset, 0);
132 		if (ret < 0) {
133 			return -errno;
134 		}
135 
136 		offset += ret;
137 	}
138 
139 	return 0;
140 }
141 
mqtt_client_tls_write_msg(struct mqtt_client * client,const struct msghdr * message)142 int mqtt_client_tls_write_msg(struct mqtt_client *client,
143 			      const struct msghdr *message)
144 {
145 	int ret, i;
146 	size_t offset = 0;
147 	size_t total_len = 0;
148 
149 	for (i = 0; i < message->msg_iovlen; i++) {
150 		total_len += message->msg_iov[i].iov_len;
151 	}
152 
153 	while (offset < total_len) {
154 		ret = zsock_sendmsg(client->transport.tls.sock, message, 0);
155 		if (ret < 0) {
156 			return -errno;
157 		}
158 
159 		offset += ret;
160 		if (offset >= total_len) {
161 			break;
162 		}
163 
164 		/* Update msghdr for the next iteration. */
165 		for (i = 0; i < message->msg_iovlen; i++) {
166 			if (ret < message->msg_iov[i].iov_len) {
167 				message->msg_iov[i].iov_len -= ret;
168 				message->msg_iov[i].iov_base =
169 					(uint8_t *)message->msg_iov[i].iov_base + ret;
170 				break;
171 			}
172 
173 			ret -= message->msg_iov[i].iov_len;
174 			message->msg_iov[i].iov_len = 0;
175 		}
176 	}
177 
178 	return 0;
179 }
180 
mqtt_client_tls_read(struct mqtt_client * client,uint8_t * data,uint32_t buflen,bool shall_block)181 int mqtt_client_tls_read(struct mqtt_client *client, uint8_t *data, uint32_t buflen,
182 			 bool shall_block)
183 {
184 	int flags = 0;
185 	int ret;
186 
187 	if (!shall_block) {
188 		flags |= ZSOCK_MSG_DONTWAIT;
189 	}
190 
191 	ret = zsock_recv(client->transport.tls.sock, data, buflen, flags);
192 	if (ret < 0) {
193 		return -errno;
194 	}
195 
196 	return ret;
197 }
198 
mqtt_client_tls_disconnect(struct mqtt_client * client)199 int mqtt_client_tls_disconnect(struct mqtt_client *client)
200 {
201 	int ret;
202 
203 	NET_INFO("Closing socket %d", client->transport.tls.sock);
204 	ret = zsock_close(client->transport.tls.sock);
205 	if (ret < 0) {
206 		return -errno;
207 	}
208 
209 	return 0;
210 }
211