1 /*
2 * SPDX-License-Identifier: Apache-2.0
3 * Copyright (c) 2024 Nordic Semiconductor ASA
4 */
5
6 #include <zephyr/logging/log.h>
7 LOG_MODULE_REGISTER(tls_configuration_sample, LOG_LEVEL_INF);
8
9 #include <zephyr/kernel.h>
10 #include <errno.h>
11 #include <stdio.h>
12
13 #include <zephyr/posix/sys/eventfd.h>
14 #include <zephyr/posix/poll.h>
15 #include <zephyr/posix/arpa/inet.h>
16 #include <zephyr/posix/unistd.h>
17 #include <zephyr/posix/sys/socket.h>
18
19 #include <zephyr/net/socket.h>
20 #include <zephyr/net/tls_credentials.h>
21 #include <zephyr/net/net_if.h>
22 #include <zephyr/sys/util.h>
23
24 /* This include is required for the definition of the Mbed TLS internal symbol
25 * MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED.
26 */
27 #include <mbedtls/ssl_ciphersuites.h>
28
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
/* Pre-shared key and identity, registered under PSK_TAG by setup_credentials().
 * NOTE(review): presumably these must match the PSK/identity configured on
 * the TLS server side - confirm against the server setup.
 */
static const unsigned char psk[] = { 0x01, 0x02, 0x03, 0x04, 0x05 };
static const char psk_id[] = "PSK_identity";
#endif /* MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED */

/* The following certificates (*.inc files) are:
 * - generated by the "create-certs.sh" script
 * - converted into C array form by the CMakeLists file
 *
 * An RSA certificate is embedded when a PSA RSA signature algorithm is
 * enabled; otherwise an EC certificate is embedded when ECDSA is enabled.
 * In either case USE_CERTIFICATE is defined, and setup_credentials()
 * registers the array as the CA certificate (CA_CERTIFICATE_TAG). When
 * neither algorithm is enabled, no certificate is compiled in.
 */
#if defined(CONFIG_PSA_WANT_ALG_RSA_PKCS1V15_SIGN) || defined(CONFIG_PSA_WANT_ALG_RSA_PSS)
#define USE_CERTIFICATE
static const unsigned char certificate[] = {
#include "rsa.crt.inc"
};
#elif defined(CONFIG_PSA_WANT_ALG_ECDSA)
#define USE_CERTIFICATE
static const unsigned char certificate[] = {
#include "ec.crt.inc"
};
#endif
49
#define APP_BANNER "TLS socket configuration sample"

/* Value held by socket_fd while no socket is open. */
#define INVALID_SOCKET (-1)

/* Security tags under which the TLS credentials are registered and then
 * referenced from the socket's TLS_SEC_TAG_LIST.
 * TAG_PLACEHOLDER pins value 0 so that the real tags start at 1.
 * NOTE(review): presumably tag 0 is avoided on purpose - confirm.
 */
enum {
	TAG_PLACEHOLDER = 0,
#if defined(USE_CERTIFICATE)
	CA_CERTIFICATE_TAG,
#endif
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
	PSK_TAG,
#endif
};

/* Descriptor of the TLS socket; INVALID_SOCKET while none is open. */
static int socket_fd = INVALID_SOCKET;
/* Poll set holding only socket_fd (filled in once the socket is connected). */
static struct pollfd fds[1];

/* Keep the new line because openssl uses that to start processing the incoming data */
#define TEST_STRING "hello world\n"
/* Sized to include the terminating NUL so the buffer can be logged with %s. */
static uint8_t test_buf[sizeof(TEST_STRING)];
70
wait_for_event(void)71 static int wait_for_event(void)
72 {
73 int ret;
74
75 /* Wait for event on any socket used. Once event occurs,
76 * we'll check them all.
77 */
78 ret = poll(fds, ARRAY_SIZE(fds), -1);
79 if (ret < 0) {
80 LOG_ERR("Error in poll (%d)", errno);
81 return ret;
82 }
83
84 return 0;
85 }
86
create_socket(void)87 static int create_socket(void)
88 {
89 int ret = 0;
90 struct net_sockaddr_in addr;
91
92 addr.sin_family = NET_AF_INET;
93 addr.sin_port = net_htons(CONFIG_SERVER_PORT);
94 inet_pton(NET_AF_INET, "127.0.0.1", &addr.sin_addr);
95
96 #if defined(CONFIG_MBEDTLS_SSL_PROTO_TLS1_3)
97 socket_fd = socket(addr.sin_family, NET_SOCK_STREAM, IPPROTO_TLS_1_3);
98 #else
99 socket_fd = socket(addr.sin_family, NET_SOCK_STREAM, IPPROTO_TLS_1_2);
100 #endif
101 if (socket_fd < 0) {
102 LOG_ERR("Failed to create TLS socket (%d)", errno);
103 return -errno;
104 }
105
106 sec_tag_t sec_tag_list[] = {
107 #if defined(USE_CERTIFICATE)
108 CA_CERTIFICATE_TAG,
109 #endif
110 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
111 PSK_TAG,
112 #endif
113 };
114
115 ret = setsockopt(socket_fd, SOL_TLS, TLS_SEC_TAG_LIST,
116 sec_tag_list, sizeof(sec_tag_list));
117 if (ret < 0) {
118 LOG_ERR("Failed to set TLS_SEC_TAG_LIST option (%d)", errno);
119 return -errno;
120 }
121
122 /* HOSTNAME is only required for key exchanges that use a certificate. */
123 #if defined(USE_CERTIFICATE)
124 ret = setsockopt(socket_fd, SOL_TLS, TLS_HOSTNAME,
125 "localhost", sizeof("localhost"));
126 if (ret < 0) {
127 LOG_ERR("Failed to set TLS_HOSTNAME option (%d)", errno);
128 return -errno;
129 }
130 #endif
131
132 ret = connect(socket_fd, (struct net_sockaddr *) &addr, sizeof(addr));
133 if (ret < 0) {
134 LOG_ERR("Cannot connect to TCP remote (%d)", errno);
135 return -errno;
136 }
137
138 /* Prepare file descriptor for polling */
139 fds[0].fd = socket_fd;
140 fds[0].events = POLLIN;
141
142 return ret;
143 }
144
close_socket(void)145 void close_socket(void)
146 {
147 if (socket_fd != INVALID_SOCKET) {
148 close(socket_fd);
149 }
150 }
151
setup_credentials(void)152 static int setup_credentials(void)
153 {
154 __maybe_unused int err;
155
156 #if defined(USE_CERTIFICATE)
157 err = tls_credential_add(CA_CERTIFICATE_TAG,
158 TLS_CREDENTIAL_CA_CERTIFICATE,
159 certificate,
160 sizeof(certificate));
161 if (err < 0) {
162 LOG_ERR("Failed to register certificate: %d", err);
163 return err;
164 }
165 #endif
166
167 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
168 err = tls_credential_add(PSK_TAG,
169 TLS_CREDENTIAL_PSK,
170 psk,
171 sizeof(psk));
172 if (err < 0) {
173 LOG_ERR("Failed to register PSK: %d", err);
174 }
175 err = tls_credential_add(PSK_TAG,
176 TLS_CREDENTIAL_PSK_ID,
177 psk_id,
178 sizeof(psk_id) - 1);
179 if (err < 0) {
180 LOG_ERR("Failed to register PSK ID: %d", err);
181 }
182 #endif
183
184 return 0;
185 }
186
main(void)187 int main(void)
188 {
189 int ret;
190 size_t data_len;
191
192 LOG_INF(APP_BANNER);
193
194 setup_credentials();
195
196 ret = create_socket();
197 if (ret < 0) {
198 LOG_ERR("Socket creation failed (%d)", ret);
199 goto exit;
200 }
201
202 memcpy(test_buf, TEST_STRING, sizeof(TEST_STRING));
203 /* The -1 here is because sizeof() accounts for "\0" but that's not
204 * needed for socket functions send/recv.
205 */
206 data_len = sizeof(TEST_STRING) - 1;
207
208 /* OpenSSL s_server has only the "-rev" option as echo-like behavior
209 * which echoes back the data that we send it in reversed order. So
210 * in the following we send the test buffer twice (in the 1st iteration
211 * it will contain the original TEST_STRING, whereas in the 2nd one
212 * it will contain TEST_STRING reversed) so that in the end we can
213 * just memcmp() it against the original TEST_STRING.
214 */
215 for (int i = 0; i < 2; i++) {
216 LOG_DBG("Send: %s", test_buf);
217 ret = send(socket_fd, test_buf, data_len, 0);
218 if (ret < 0) {
219 LOG_ERR("Error sending test string (%d)", errno);
220 goto exit;
221 }
222
223 memset(test_buf, 0, sizeof(test_buf));
224
225 wait_for_event();
226
227 ret = recv(socket_fd, test_buf, data_len, MSG_WAITALL);
228 if (ret == 0) {
229 LOG_ERR("Server terminated unexpectedly");
230 ret = -EIO;
231 goto exit;
232 } else if (ret < 0) {
233 LOG_ERR("Error receiving data (%d)", errno);
234 goto exit;
235 }
236 if (ret != data_len) {
237 LOG_ERR("Sent %d bytes, but received %d", data_len, ret);
238 ret = -EINVAL;
239 goto exit;
240 }
241 LOG_DBG("Received: %s", test_buf);
242 }
243
244 ret = memcmp(TEST_STRING, test_buf, data_len);
245 if (ret != 0) {
246 LOG_ERR("Received data does not match with TEST_STRING");
247 LOG_HEXDUMP_ERR(test_buf, data_len, "Received:");
248 LOG_HEXDUMP_ERR(TEST_STRING, data_len, "Expected:");
249 ret = -EINVAL;
250 }
251
252 exit:
253 LOG_INF("Test %s", (ret < 0) ? "FAILED" : "PASSED");
254
255 close_socket();
256
257 return 0;
258 }
259