1 /*
2  * Copyright (c) 2018 Nordic Semiconductor ASA
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
/** @file mqtt_transport_socket_tls.c
8  *
9  * @brief Internal functions to handle transport over TLS socket.
10  */
11 
12 #include <zephyr/logging/log.h>
13 LOG_MODULE_REGISTER(net_mqtt_sock_tls, CONFIG_MQTT_LOG_LEVEL);
14 
15 #include <errno.h>
16 #include <zephyr/net/socket.h>
17 #include <zephyr/net/mqtt.h>
18 
19 #include "mqtt_os.h"
20 
mqtt_client_tls_connect(struct mqtt_client * client)21 int mqtt_client_tls_connect(struct mqtt_client *client)
22 {
23 	const struct sockaddr *broker = client->broker;
24 	struct mqtt_sec_config *tls_config = &client->transport.tls.config;
25 	int ret;
26 
27 	client->transport.tls.sock = zsock_socket(broker->sa_family,
28 						  SOCK_STREAM, IPPROTO_TLS_1_2);
29 	if (client->transport.tls.sock < 0) {
30 		return -errno;
31 	}
32 
33 	NET_DBG("Created socket %d", client->transport.tls.sock);
34 
35 #if defined(CONFIG_SOCKS)
36 	if (client->transport.proxy.addrlen != 0) {
37 		ret = setsockopt(client->transport.tls.sock,
38 				 SOL_SOCKET, SO_SOCKS5,
39 				 &client->transport.proxy.addr,
40 				 client->transport.proxy.addrlen);
41 		if (ret < 0) {
42 			goto error;
43 		}
44 	}
45 #endif
46 	/* Set secure socket options. */
47 	ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS, TLS_PEER_VERIFY,
48 			       &tls_config->peer_verify,
49 			       sizeof(tls_config->peer_verify));
50 	if (ret < 0) {
51 		goto error;
52 	}
53 
54 	if (tls_config->cipher_list != NULL && tls_config->cipher_count > 0) {
55 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
56 				       TLS_CIPHERSUITE_LIST, tls_config->cipher_list,
57 				       sizeof(int) * tls_config->cipher_count);
58 		if (ret < 0) {
59 			goto error;
60 		}
61 	}
62 
63 	if (tls_config->sec_tag_list != NULL && tls_config->sec_tag_count > 0) {
64 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
65 				       TLS_SEC_TAG_LIST, tls_config->sec_tag_list,
66 				       sizeof(sec_tag_t) * tls_config->sec_tag_count);
67 		if (ret < 0) {
68 			goto error;
69 		}
70 	}
71 
72 	if (tls_config->hostname) {
73 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
74 				       TLS_HOSTNAME, tls_config->hostname,
75 				       strlen(tls_config->hostname) + 1);
76 		if (ret < 0) {
77 			goto error;
78 		}
79 	}
80 
81 	if (tls_config->cert_nocopy != TLS_CERT_NOCOPY_NONE) {
82 		ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
83 				       TLS_CERT_NOCOPY, &tls_config->cert_nocopy,
84 				       sizeof(tls_config->cert_nocopy));
85 		if (ret < 0) {
86 			goto error;
87 		}
88 	}
89 
90 	size_t peer_addr_size = sizeof(struct sockaddr_in6);
91 
92 	if (broker->sa_family == AF_INET) {
93 		peer_addr_size = sizeof(struct sockaddr_in);
94 	}
95 
96 	ret = zsock_connect(client->transport.tls.sock, client->broker,
97 			    peer_addr_size);
98 	if (ret < 0) {
99 		goto error;
100 	}
101 
102 	NET_DBG("Connect completed");
103 	return 0;
104 
105 error:
106 	(void) zsock_close(client->transport.tls.sock);
107 	return -errno;
108 }
109 
mqtt_client_tls_write(struct mqtt_client * client,const uint8_t * data,uint32_t datalen)110 int mqtt_client_tls_write(struct mqtt_client *client, const uint8_t *data,
111 			  uint32_t datalen)
112 {
113 	uint32_t offset = 0U;
114 	int ret;
115 
116 	while (offset < datalen) {
117 		ret = zsock_send(client->transport.tls.sock, data + offset,
118 				 datalen - offset, 0);
119 		if (ret < 0) {
120 			return -errno;
121 		}
122 
123 		offset += ret;
124 	}
125 
126 	return 0;
127 }
128 
mqtt_client_tls_write_msg(struct mqtt_client * client,const struct msghdr * message)129 int mqtt_client_tls_write_msg(struct mqtt_client *client,
130 			      const struct msghdr *message)
131 {
132 	int ret, i;
133 	size_t offset = 0;
134 	size_t total_len = 0;
135 
136 	for (i = 0; i < message->msg_iovlen; i++) {
137 		total_len += message->msg_iov[i].iov_len;
138 	}
139 
140 	while (offset < total_len) {
141 		ret = zsock_sendmsg(client->transport.tls.sock, message, 0);
142 		if (ret < 0) {
143 			return -errno;
144 		}
145 
146 		offset += ret;
147 		if (offset >= total_len) {
148 			break;
149 		}
150 
151 		/* Update msghdr for the next iteration. */
152 		for (i = 0; i < message->msg_iovlen; i++) {
153 			if (ret < message->msg_iov[i].iov_len) {
154 				message->msg_iov[i].iov_len -= ret;
155 				message->msg_iov[i].iov_base =
156 					(uint8_t *)message->msg_iov[i].iov_base + ret;
157 				break;
158 			}
159 
160 			ret -= message->msg_iov[i].iov_len;
161 			message->msg_iov[i].iov_len = 0;
162 		}
163 	}
164 
165 	return 0;
166 }
167 
mqtt_client_tls_read(struct mqtt_client * client,uint8_t * data,uint32_t buflen,bool shall_block)168 int mqtt_client_tls_read(struct mqtt_client *client, uint8_t *data, uint32_t buflen,
169 			 bool shall_block)
170 {
171 	int flags = 0;
172 	int ret;
173 
174 	if (!shall_block) {
175 		flags |= ZSOCK_MSG_DONTWAIT;
176 	}
177 
178 	ret = zsock_recv(client->transport.tls.sock, data, buflen, flags);
179 	if (ret < 0) {
180 		return -errno;
181 	}
182 
183 	return ret;
184 }
185 
mqtt_client_tls_disconnect(struct mqtt_client * client)186 int mqtt_client_tls_disconnect(struct mqtt_client *client)
187 {
188 	int ret;
189 
190 	NET_INFO("Closing socket %d", client->transport.tls.sock);
191 	ret = zsock_close(client->transport.tls.sock);
192 	if (ret < 0) {
193 		return -errno;
194 	}
195 
196 	return 0;
197 }
198