1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  * Copyright (c) 2024 Nordic Semiconductor ASA
4  */
5 
6 #include <zephyr/logging/log.h>
/* Register this file's log module. The LOG_INF/LOG_ERR/LOG_DBG/LOG_HEXDUMP_ERR
 * calls below log under the name "tls_configuration_sample". The module level
 * is LOG_LEVEL_INF, so the LOG_DBG traces in main() are presumably filtered
 * out unless the level is raised. NOTE(review): confirm against the Zephyr
 * logging configuration.
 */
LOG_MODULE_REGISTER(tls_configuration_sample, LOG_LEVEL_INF);
8 
#include <zephyr/kernel.h>
#include <errno.h>
#include <stdio.h>
#include <string.h>
12 
13 #include <zephyr/posix/sys/eventfd.h>
14 #include <zephyr/posix/poll.h>
15 #include <zephyr/posix/arpa/inet.h>
16 #include <zephyr/posix/unistd.h>
17 #include <zephyr/posix/sys/socket.h>
18 
19 #include <zephyr/net/socket.h>
20 #include <zephyr/net/tls_credentials.h>
21 #include <zephyr/net/net_if.h>
22 #include <zephyr/sys/util.h>
23 
24 /* This include is required for the definition of the Mbed TLS internal symbol
25  * MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED.
26  */
27 #include <mbedtls/ssl_ciphersuites.h>
28 
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
/* Pre-shared key and PSK identity, registered under PSK_TAG by
 * setup_credentials(). The identity is registered without its terminating NUL.
 * NOTE(review): hard-coded demo key material; presumably it must match the PSK
 * configured on the test server. Confirm.
 */
static const unsigned char psk[] = { 0x01, 0x02, 0x03, 0x04, 0x05 };
static const char psk_id[] = "PSK_identity";
#endif /* MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED */
33 
/* The following certificates (*.inc files) are:
 * - generated by the "create-certs.sh" script
 * - converted to C array form by the CMakeLists.txt file
 */
/* Pick the CA certificate that matches the enabled PSA signature algorithm.
 * RSA (PKCS#1 v1.5 or PSS) takes precedence over ECDSA. When neither is
 * enabled, USE_CERTIFICATE stays undefined and no certificate is compiled in.
 * The selected array is registered as TLS_CREDENTIAL_CA_CERTIFICATE under
 * CA_CERTIFICATE_TAG by setup_credentials().
 */
#if defined(CONFIG_PSA_WANT_ALG_RSA_PKCS1V15_SIGN) || defined(CONFIG_PSA_WANT_ALG_RSA_PSS)
#define USE_CERTIFICATE
static const unsigned char certificate[] = {
#include "rsa.crt.inc"
};
#elif defined(CONFIG_PSA_WANT_ALG_ECDSA)
#define USE_CERTIFICATE
static const unsigned char certificate[] = {
#include "ec.crt.inc"
};
#endif
49 
#define APP_BANNER "TLS socket configuration sample"

/* Value of socket_fd while no socket is open. */
#define INVALID_SOCKET (-1)

/* Security tags for the credentials registered in setup_credentials().
 * PLACEHOLDER_TAG makes the first real tag 1 and keeps the enum non-empty
 * when neither a certificate nor PSK support is compiled in. It carries no
 * leading underscore because _[A-Z]... identifiers are reserved in C.
 */
enum {
	PLACEHOLDER_TAG = 0,
#if defined(USE_CERTIFICATE)
	CA_CERTIFICATE_TAG,
#endif
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
	PSK_TAG,
#endif
};

/* The sample's single TLS socket and the poll() entry that watches it. */
static int socket_fd = INVALID_SOCKET;
static struct pollfd fds[1];

/* Keep the new line because openssl uses that to start processing the incoming data */
#define TEST_STRING "hello world\n"
/* Holds TEST_STRING including its NUL terminator, so the buffer remains a
 * printable string after receiving sizeof(TEST_STRING) - 1 bytes into it.
 */
static uint8_t test_buf[sizeof(TEST_STRING)];
70 
wait_for_event(void)71 static int wait_for_event(void)
72 {
73 	int ret;
74 
75 	/* Wait for event on any socket used. Once event occurs,
76 	 * we'll check them all.
77 	 */
78 	ret = poll(fds, ARRAY_SIZE(fds), -1);
79 	if (ret < 0) {
80 		LOG_ERR("Error in poll (%d)", errno);
81 		return ret;
82 	}
83 
84 	return 0;
85 }
86 
create_socket(void)87 static int create_socket(void)
88 {
89 	int ret = 0;
90 	struct net_sockaddr_in addr;
91 
92 	addr.sin_family = NET_AF_INET;
93 	addr.sin_port = net_htons(CONFIG_SERVER_PORT);
94 	inet_pton(NET_AF_INET, "127.0.0.1", &addr.sin_addr);
95 
96 #if defined(CONFIG_MBEDTLS_SSL_PROTO_TLS1_3)
97 	socket_fd = socket(addr.sin_family, NET_SOCK_STREAM, IPPROTO_TLS_1_3);
98 #else
99 	socket_fd = socket(addr.sin_family, NET_SOCK_STREAM, IPPROTO_TLS_1_2);
100 #endif
101 	if (socket_fd < 0) {
102 		LOG_ERR("Failed to create TLS socket (%d)", errno);
103 		return -errno;
104 	}
105 
106 	sec_tag_t sec_tag_list[] = {
107 #if defined(USE_CERTIFICATE)
108 		CA_CERTIFICATE_TAG,
109 #endif
110 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
111 		PSK_TAG,
112 #endif
113 	};
114 
115 	ret = setsockopt(socket_fd, SOL_TLS, TLS_SEC_TAG_LIST,
116 			sec_tag_list, sizeof(sec_tag_list));
117 	if (ret < 0) {
118 		LOG_ERR("Failed to set TLS_SEC_TAG_LIST option (%d)", errno);
119 		return -errno;
120 	}
121 
122 	/* HOSTNAME is only required for key exchanges that use a certificate. */
123 #if defined(USE_CERTIFICATE)
124 	ret = setsockopt(socket_fd, SOL_TLS, TLS_HOSTNAME,
125 			 "localhost", sizeof("localhost"));
126 	if (ret < 0) {
127 		LOG_ERR("Failed to set TLS_HOSTNAME option (%d)", errno);
128 		return -errno;
129 	}
130 #endif
131 
132 	ret = connect(socket_fd, (struct net_sockaddr *) &addr, sizeof(addr));
133 	if (ret < 0) {
134 		LOG_ERR("Cannot connect to TCP remote (%d)", errno);
135 		return -errno;
136 	}
137 
138 	/* Prepare file descriptor for polling */
139 	fds[0].fd = socket_fd;
140 	fds[0].events = POLLIN;
141 
142 	return ret;
143 }
144 
close_socket(void)145 void close_socket(void)
146 {
147 	if (socket_fd != INVALID_SOCKET) {
148 		close(socket_fd);
149 	}
150 }
151 
setup_credentials(void)152 static int setup_credentials(void)
153 {
154 	__maybe_unused int err;
155 
156 #if defined(USE_CERTIFICATE)
157 	err = tls_credential_add(CA_CERTIFICATE_TAG,
158 				TLS_CREDENTIAL_CA_CERTIFICATE,
159 				certificate,
160 				sizeof(certificate));
161 	if (err < 0) {
162 		LOG_ERR("Failed to register certificate: %d", err);
163 		return err;
164 	}
165 #endif
166 
167 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
168 	err = tls_credential_add(PSK_TAG,
169 				TLS_CREDENTIAL_PSK,
170 				psk,
171 				sizeof(psk));
172 	if (err < 0) {
173 		LOG_ERR("Failed to register PSK: %d", err);
174 	}
175 	err = tls_credential_add(PSK_TAG,
176 				TLS_CREDENTIAL_PSK_ID,
177 				psk_id,
178 				sizeof(psk_id) - 1);
179 	if (err < 0) {
180 		LOG_ERR("Failed to register PSK ID: %d", err);
181 	}
182 #endif
183 
184 	return 0;
185 }
186 
main(void)187 int main(void)
188 {
189 	int ret;
190 	size_t data_len;
191 
192 	LOG_INF(APP_BANNER);
193 
194 	setup_credentials();
195 
196 	ret = create_socket();
197 	if (ret < 0) {
198 		LOG_ERR("Socket creation failed (%d)", ret);
199 		goto exit;
200 	}
201 
202 	memcpy(test_buf, TEST_STRING, sizeof(TEST_STRING));
203 	/* The -1 here is because sizeof() accounts for "\0" but that's not
204 	 * needed for socket functions send/recv.
205 	 */
206 	data_len = sizeof(TEST_STRING) - 1;
207 
208 	/* OpenSSL s_server has only the "-rev" option as echo-like behavior
209 	 * which echoes back the data that we send it in reversed order. So
210 	 * in the following we send the test buffer twice (in the 1st iteration
211 	 * it will contain the original TEST_STRING, whereas in the 2nd one
212 	 * it will contain TEST_STRING reversed) so that in the end we can
213 	 * just memcmp() it against the original TEST_STRING.
214 	 */
215 	for (int i = 0; i < 2; i++) {
216 		LOG_DBG("Send: %s", test_buf);
217 		ret = send(socket_fd, test_buf, data_len, 0);
218 		if (ret < 0) {
219 			LOG_ERR("Error sending test string (%d)", errno);
220 			goto exit;
221 		}
222 
223 		memset(test_buf, 0, sizeof(test_buf));
224 
225 		wait_for_event();
226 
227 		ret = recv(socket_fd, test_buf, data_len, MSG_WAITALL);
228 		if (ret == 0) {
229 			LOG_ERR("Server terminated unexpectedly");
230 			ret = -EIO;
231 			goto exit;
232 		} else if (ret < 0) {
233 			LOG_ERR("Error receiving data (%d)", errno);
234 			goto exit;
235 		}
236 		if (ret != data_len) {
237 			LOG_ERR("Sent %d bytes, but received %d", data_len, ret);
238 			ret = -EINVAL;
239 			goto exit;
240 		}
241 		LOG_DBG("Received: %s", test_buf);
242 	}
243 
244 	ret = memcmp(TEST_STRING, test_buf, data_len);
245 	if (ret != 0) {
246 		LOG_ERR("Received data does not match with TEST_STRING");
247 		LOG_HEXDUMP_ERR(test_buf, data_len, "Received:");
248 		LOG_HEXDUMP_ERR(TEST_STRING, data_len, "Expected:");
249 		ret = -EINVAL;
250 	}
251 
252 exit:
253 	LOG_INF("Test %s", (ret < 0) ? "FAILED" : "PASSED");
254 
255 	close_socket();
256 
257 	return 0;
258 }
259