1 /*
2 * Copyright (c) 2018 Nordic Semiconductor ASA
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
/** @file mqtt_transport_socket_tls.c
8 *
9 * @brief Internal functions to handle transport over TLS socket.
10 */
11
12 #include <zephyr/logging/log.h>
13 LOG_MODULE_REGISTER(net_mqtt_sock_tls, CONFIG_MQTT_LOG_LEVEL);
14
15 #include <errno.h>
16 #include <zephyr/net/socket.h>
17 #include <zephyr/net/mqtt.h>
18
19 #include "mqtt_os.h"
20
mqtt_client_tls_connect(struct mqtt_client * client)21 int mqtt_client_tls_connect(struct mqtt_client *client)
22 {
23 const struct sockaddr *broker = client->broker;
24 struct mqtt_sec_config *tls_config = &client->transport.tls.config;
25 int ret;
26
27 client->transport.tls.sock = zsock_socket(broker->sa_family,
28 SOCK_STREAM, IPPROTO_TLS_1_2);
29 if (client->transport.tls.sock < 0) {
30 return -errno;
31 }
32
33 NET_DBG("Created socket %d", client->transport.tls.sock);
34
35 #if defined(CONFIG_SOCKS)
36 if (client->transport.proxy.addrlen != 0) {
37 ret = setsockopt(client->transport.tls.sock,
38 SOL_SOCKET, SO_SOCKS5,
39 &client->transport.proxy.addr,
40 client->transport.proxy.addrlen);
41 if (ret < 0) {
42 goto error;
43 }
44 }
45 #endif
46 /* Set secure socket options. */
47 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS, TLS_PEER_VERIFY,
48 &tls_config->peer_verify,
49 sizeof(tls_config->peer_verify));
50 if (ret < 0) {
51 goto error;
52 }
53
54 if (tls_config->cipher_list != NULL && tls_config->cipher_count > 0) {
55 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
56 TLS_CIPHERSUITE_LIST, tls_config->cipher_list,
57 sizeof(int) * tls_config->cipher_count);
58 if (ret < 0) {
59 goto error;
60 }
61 }
62
63 if (tls_config->sec_tag_list != NULL && tls_config->sec_tag_count > 0) {
64 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
65 TLS_SEC_TAG_LIST, tls_config->sec_tag_list,
66 sizeof(sec_tag_t) * tls_config->sec_tag_count);
67 if (ret < 0) {
68 goto error;
69 }
70 }
71
72 #if defined(CONFIG_MQTT_LIB_TLS_USE_ALPN)
73 if (tls_config->alpn_protocol_name_list != NULL &&
74 tls_config->alpn_protocol_name_count > 0) {
75 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
76 TLS_ALPN_LIST, tls_config->alpn_protocol_name_list,
77 sizeof(const char *) * tls_config->alpn_protocol_name_count);
78 if (ret < 0) {
79 goto error;
80 }
81 }
82
83 #endif
84
85 if (tls_config->hostname) {
86 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
87 TLS_HOSTNAME, tls_config->hostname,
88 strlen(tls_config->hostname) + 1);
89 if (ret < 0) {
90 goto error;
91 }
92 }
93
94 if (tls_config->cert_nocopy != TLS_CERT_NOCOPY_NONE) {
95 ret = zsock_setsockopt(client->transport.tls.sock, SOL_TLS,
96 TLS_CERT_NOCOPY, &tls_config->cert_nocopy,
97 sizeof(tls_config->cert_nocopy));
98 if (ret < 0) {
99 goto error;
100 }
101 }
102
103 size_t peer_addr_size = sizeof(struct sockaddr_in6);
104
105 if (broker->sa_family == AF_INET) {
106 peer_addr_size = sizeof(struct sockaddr_in);
107 }
108
109 ret = zsock_connect(client->transport.tls.sock, client->broker,
110 peer_addr_size);
111 if (ret < 0) {
112 goto error;
113 }
114
115 NET_DBG("Connect completed");
116 return 0;
117
118 error:
119 (void) zsock_close(client->transport.tls.sock);
120 return -errno;
121 }
122
mqtt_client_tls_write(struct mqtt_client * client,const uint8_t * data,uint32_t datalen)123 int mqtt_client_tls_write(struct mqtt_client *client, const uint8_t *data,
124 uint32_t datalen)
125 {
126 uint32_t offset = 0U;
127 int ret;
128
129 while (offset < datalen) {
130 ret = zsock_send(client->transport.tls.sock, data + offset,
131 datalen - offset, 0);
132 if (ret < 0) {
133 return -errno;
134 }
135
136 offset += ret;
137 }
138
139 return 0;
140 }
141
mqtt_client_tls_write_msg(struct mqtt_client * client,const struct msghdr * message)142 int mqtt_client_tls_write_msg(struct mqtt_client *client,
143 const struct msghdr *message)
144 {
145 int ret, i;
146 size_t offset = 0;
147 size_t total_len = 0;
148
149 for (i = 0; i < message->msg_iovlen; i++) {
150 total_len += message->msg_iov[i].iov_len;
151 }
152
153 while (offset < total_len) {
154 ret = zsock_sendmsg(client->transport.tls.sock, message, 0);
155 if (ret < 0) {
156 return -errno;
157 }
158
159 offset += ret;
160 if (offset >= total_len) {
161 break;
162 }
163
164 /* Update msghdr for the next iteration. */
165 for (i = 0; i < message->msg_iovlen; i++) {
166 if (ret < message->msg_iov[i].iov_len) {
167 message->msg_iov[i].iov_len -= ret;
168 message->msg_iov[i].iov_base =
169 (uint8_t *)message->msg_iov[i].iov_base + ret;
170 break;
171 }
172
173 ret -= message->msg_iov[i].iov_len;
174 message->msg_iov[i].iov_len = 0;
175 }
176 }
177
178 return 0;
179 }
180
mqtt_client_tls_read(struct mqtt_client * client,uint8_t * data,uint32_t buflen,bool shall_block)181 int mqtt_client_tls_read(struct mqtt_client *client, uint8_t *data, uint32_t buflen,
182 bool shall_block)
183 {
184 int flags = 0;
185 int ret;
186
187 if (!shall_block) {
188 flags |= ZSOCK_MSG_DONTWAIT;
189 }
190
191 ret = zsock_recv(client->transport.tls.sock, data, buflen, flags);
192 if (ret < 0) {
193 return -errno;
194 }
195
196 return ret;
197 }
198
mqtt_client_tls_disconnect(struct mqtt_client * client)199 int mqtt_client_tls_disconnect(struct mqtt_client *client)
200 {
201 int ret;
202
203 NET_INFO("Closing socket %d", client->transport.tls.sock);
204 ret = zsock_close(client->transport.tls.sock);
205 if (ret < 0) {
206 return -errno;
207 }
208
209 return 0;
210 }
211