1 /*
2 * SPDX-License-Identifier: Apache-2.0
3 * Copyright (c) 2024 Nordic Semiconductor ASA
4 */
5
6 #include <zephyr/logging/log.h>
7 LOG_MODULE_REGISTER(tls_configuration_sample, LOG_LEVEL_INF);
8
9 #include <zephyr/kernel.h>
10 #include <errno.h>
11 #include <stdio.h>
12
13 #include <zephyr/posix/sys/eventfd.h>
14
15 #include <zephyr/net/socket.h>
16 #include <zephyr/net/tls_credentials.h>
17 #include <zephyr/net/net_if.h>
18 #include <zephyr/sys/util.h>
19
20 /* This include is required for the definition of the Mbed TLS internal symbol
21 * MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED.
22 */
23 #include <mbedtls/ssl_ciphersuites.h>
24
/* Pre-shared key material. It is only built when Mbed TLS has a PSK-based
 * key exchange enabled. Both values are registered under PSK_TAG in
 * setup_credentials(). They presumably have to match the key/identity the
 * test server is started with — TODO confirm against the server setup.
 */
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
static const unsigned char psk[] = { 0x01, 0x02, 0x03, 0x04, 0x05 };
static const char psk_id[] = "PSK_identity";
#endif /* MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED */

/* The following certificates (*.inc files) are:
 * - generated by the "create-certs.sh" script;
 * - converted into C array initializers by the CMakeLists.txt file.
 *
 * The RSA certificate is selected when an RSA signature algorithm is enabled
 * in PSA. Otherwise the EC certificate is used when ECDSA is enabled.
 * USE_CERTIFICATE is defined only when one of them is compiled in. The array
 * is registered as the CA certificate in setup_credentials().
 */
#if defined(CONFIG_PSA_WANT_ALG_RSA_PKCS1V15_SIGN) || defined(CONFIG_PSA_WANT_ALG_RSA_PSS)
#define USE_CERTIFICATE
static const unsigned char certificate[] = {
#include "rsa.crt.inc"
};
#elif defined(CONFIG_PSA_WANT_ALG_ECDSA)
#define USE_CERTIFICATE
static const unsigned char certificate[] = {
#include "ec.crt.inc"
};
#endif

#define APP_BANNER "TLS socket configuration sample"

/* Value held by socket_fd while no socket is open. */
#define INVALID_SOCKET (-1)

/* Security tags used to register (setup_credentials) and select
 * (create_socket) the TLS credentials.
 * _PLACEHOLDER_TAG_ keeps the enum non-empty when neither credential type is
 * configured, because C does not allow an enum without enumerators. It also
 * makes the real tags start at 1.
 */
enum {
	_PLACEHOLDER_TAG_ = 0,
#if defined(USE_CERTIFICATE)
	CA_CERTIFICATE_TAG,
#endif
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
	PSK_TAG,
#endif
};

/* The single TLS client socket. It stays INVALID_SOCKET until create_socket()
 * opens it. fds is the poll set watching it; create_socket() fills it in and
 * wait_for_event() uses it.
 */
static int socket_fd = INVALID_SOCKET;
static struct pollfd fds[1];

/* Keep the newline: openssl uses it as the trigger to start processing the incoming data */
#define TEST_STRING "hello world\n"
/* Send/receive buffer. sizeof() includes room for the terminating NUL, so the
 * buffer can be logged with "%s".
 */
static uint8_t test_buf[sizeof(TEST_STRING)];
66
wait_for_event(void)67 static int wait_for_event(void)
68 {
69 int ret;
70
71 /* Wait for event on any socket used. Once event occurs,
72 * we'll check them all.
73 */
74 ret = poll(fds, ARRAY_SIZE(fds), -1);
75 if (ret < 0) {
76 LOG_ERR("Error in poll (%d)", errno);
77 return ret;
78 }
79
80 return 0;
81 }
82
create_socket(void)83 static int create_socket(void)
84 {
85 int ret = 0;
86 struct sockaddr_in addr;
87
88 addr.sin_family = AF_INET;
89 addr.sin_port = htons(CONFIG_SERVER_PORT);
90 inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);
91
92 #if defined(CONFIG_MBEDTLS_TLS_VERSION_1_3)
93 socket_fd = socket(addr.sin_family, SOCK_STREAM, IPPROTO_TLS_1_3);
94 #else
95 socket_fd = socket(addr.sin_family, SOCK_STREAM, IPPROTO_TLS_1_2);
96 #endif
97 if (socket_fd < 0) {
98 LOG_ERR("Failed to create TLS socket (%d)", errno);
99 return -errno;
100 }
101
102 sec_tag_t sec_tag_list[] = {
103 #if defined(USE_CERTIFICATE)
104 CA_CERTIFICATE_TAG,
105 #endif
106 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
107 PSK_TAG,
108 #endif
109 };
110
111 ret = setsockopt(socket_fd, SOL_TLS, TLS_SEC_TAG_LIST,
112 sec_tag_list, sizeof(sec_tag_list));
113 if (ret < 0) {
114 LOG_ERR("Failed to set TLS_SEC_TAG_LIST option (%d)", errno);
115 return -errno;
116 }
117
118 /* HOSTNAME is only required for key exchanges that use a certificate. */
119 #if defined(USE_CERTIFICATE)
120 ret = setsockopt(socket_fd, SOL_TLS, TLS_HOSTNAME,
121 "localhost", sizeof("localhost"));
122 if (ret < 0) {
123 LOG_ERR("Failed to set TLS_HOSTNAME option (%d)", errno);
124 return -errno;
125 }
126 #endif
127
128 ret = connect(socket_fd, (struct sockaddr *) &addr, sizeof(addr));
129 if (ret < 0) {
130 LOG_ERR("Cannot connect to TCP remote (%d)", errno);
131 return -errno;
132 }
133
134 /* Prepare file descriptor for polling */
135 fds[0].fd = socket_fd;
136 fds[0].events = POLLIN;
137
138 return ret;
139 }
140
close_socket(void)141 void close_socket(void)
142 {
143 if (socket_fd != INVALID_SOCKET) {
144 close(socket_fd);
145 }
146 }
147
/**
 * Register this sample's TLS credentials with the TLS credentials subsystem.
 *
 * Depending on the build configuration this registers:
 * - the CA certificate under CA_CERTIFICATE_TAG, and/or
 * - the PSK and its identity under PSK_TAG.
 *
 * Every failure is reported to the caller. Previously, PSK registration
 * failures were only logged and success was returned anyway.
 *
 * @return 0 on success, or the negative error returned by
 *         tls_credential_add() for the first credential that failed.
 */
static int setup_credentials(void)
{
	int err;

#if defined(USE_CERTIFICATE)
	err = tls_credential_add(CA_CERTIFICATE_TAG,
				 TLS_CREDENTIAL_CA_CERTIFICATE,
				 certificate,
				 sizeof(certificate));
	if (err < 0) {
		LOG_ERR("Failed to register certificate: %d", err);
		return err;
	}
#endif

#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
	err = tls_credential_add(PSK_TAG,
				 TLS_CREDENTIAL_PSK,
				 psk,
				 sizeof(psk));
	if (err < 0) {
		LOG_ERR("Failed to register PSK: %d", err);
		return err;
	}

	/* The identity is registered without its NUL terminator, hence the -1. */
	err = tls_credential_add(PSK_TAG,
				 TLS_CREDENTIAL_PSK_ID,
				 psk_id,
				 sizeof(psk_id) - 1);
	if (err < 0) {
		LOG_ERR("Failed to register PSK ID: %d", err);
		return err;
	}
#endif

	return 0;
}
182
main(void)183 int main(void)
184 {
185 int ret;
186 size_t data_len;
187
188 LOG_INF(APP_BANNER);
189
190 setup_credentials();
191
192 ret = create_socket();
193 if (ret < 0) {
194 LOG_ERR("Socket creation failed (%d)", ret);
195 goto exit;
196 }
197
198 memcpy(test_buf, TEST_STRING, sizeof(TEST_STRING));
199 /* The -1 here is because sizeof() accounts for "\0" but that's not
200 * needed for socket functions send/recv.
201 */
202 data_len = sizeof(TEST_STRING) - 1;
203
204 /* OpenSSL s_server has only the "-rev" option as echo-like behavior
205 * which echoes back the data that we send it in reversed order. So
206 * in the following we send the test buffer twice (in the 1st iteration
207 * it will contain the original TEST_STRING, whereas in the 2nd one
208 * it will contain TEST_STRING reversed) so that in the end we can
209 * just memcmp() it against the original TEST_STRING.
210 */
211 for (int i = 0; i < 2; i++) {
212 LOG_DBG("Send: %s", test_buf);
213 ret = send(socket_fd, test_buf, data_len, 0);
214 if (ret < 0) {
215 LOG_ERR("Error sending test string (%d)", errno);
216 goto exit;
217 }
218
219 memset(test_buf, 0, sizeof(test_buf));
220
221 wait_for_event();
222
223 ret = recv(socket_fd, test_buf, data_len, MSG_WAITALL);
224 if (ret == 0) {
225 LOG_ERR("Server terminated unexpectedly");
226 ret = -EIO;
227 goto exit;
228 } else if (ret < 0) {
229 LOG_ERR("Error receiving data (%d)", errno);
230 goto exit;
231 }
232 if (ret != data_len) {
233 LOG_ERR("Sent %d bytes, but received %d", data_len, ret);
234 ret = -EINVAL;
235 goto exit;
236 }
237 LOG_DBG("Received: %s", test_buf);
238 }
239
240 ret = memcmp(TEST_STRING, test_buf, data_len);
241 if (ret != 0) {
242 LOG_ERR("Received data does not match with TEST_STRING");
243 LOG_HEXDUMP_ERR(test_buf, data_len, "Received:");
244 LOG_HEXDUMP_ERR(TEST_STRING, data_len, "Expected:");
245 ret = -EINVAL;
246 }
247
248 exit:
249 LOG_INF("Test %s", (ret < 0) ? "FAILED" : "PASSED");
250
251 close_socket();
252
253 return 0;
254 }
255