1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  * Copyright (c) 2024 Nordic Semiconductor ASA
4  */
5 
6 #include <zephyr/logging/log.h>
7 LOG_MODULE_REGISTER(tls_configuration_sample, LOG_LEVEL_INF);
8 
9 #include <zephyr/kernel.h>
10 #include <errno.h>
11 #include <stdio.h>
12 
13 #include <zephyr/posix/sys/eventfd.h>
14 
15 #include <zephyr/net/socket.h>
16 #include <zephyr/net/tls_credentials.h>
17 #include <zephyr/net/net_if.h>
18 #include <zephyr/sys/util.h>
19 
20 /* This include is required for the definition of the Mbed TLS internal symbol
21  * MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED.
22  */
23 #include <mbedtls/ssl_ciphersuites.h>
24 
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
/* Hard-coded sample pre-shared key and its identity. setup_credentials()
 * registers both under PSK_TAG; the identity is registered without its
 * terminating NUL. Only compiled when Mbed TLS has a PSK-based handshake
 * enabled.
 * NOTE(review): the peer presumably has to be configured with the same
 * key/identity for the PSK handshake to succeed — confirm against the
 * server setup used by this sample.
 */
static const unsigned char psk[] = { 0x01, 0x02, 0x03, 0x04, 0x05 };
static const char psk_id[] = "PSK_identity";
#endif /* MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED */
29 
/* The following certificates (*.inc files) are:
 * - generated by the "create-certs.sh" script;
 * - converted into C array initializers by the CMakeLists.txt file.
 *
 * The RSA certificate is used when an RSA signature algorithm (PKCS#1 v1.5
 * or PSS) is enabled in PSA; otherwise the EC certificate is used when
 * ECDSA is enabled. In either case USE_CERTIFICATE is defined so that the
 * rest of the file registers and references the CA certificate. If neither
 * is enabled, no certificate is compiled in.
 */
#if defined(CONFIG_PSA_WANT_ALG_RSA_PKCS1V15_SIGN) || defined(CONFIG_PSA_WANT_ALG_RSA_PSS)
#define USE_CERTIFICATE
static const unsigned char certificate[] = {
#include "rsa.crt.inc"
};
#elif defined(CONFIG_PSA_WANT_ALG_ECDSA)
#define USE_CERTIFICATE
static const unsigned char certificate[] = {
#include "ec.crt.inc"
};
#endif
45 
/* Banner logged once by main() at startup. */
#define APP_BANNER "TLS socket configuration sample"

/* Value held by socket_fd while no socket is open. */
#define INVALID_SOCKET (-1)
49 
/* Security tags under which setup_credentials() registers the credentials
 * and which create_socket() passes to TLS_SEC_TAG_LIST.
 *
 * PLACEHOLDER_TAG occupies value 0 so that the real tags start at 1
 * (presumably to keep every real tag non-zero — TODO confirm). It was
 * previously spelled _PLACEHOLDER_TAG_, but identifiers starting with an
 * underscore followed by an uppercase letter are reserved for the
 * implementation.
 */
enum {
	PLACEHOLDER_TAG = 0,
#if defined(USE_CERTIFICATE)
	CA_CERTIFICATE_TAG,
#endif
#if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
	PSK_TAG,
#endif
};
59 
/* Descriptor of the TLS client socket; INVALID_SOCKET while none is open. */
static int socket_fd = INVALID_SOCKET;
/* Poll set used by wait_for_event(); fds[0] is filled in by create_socket(). */
static struct pollfd fds[1];

/* Keep the trailing newline: OpenSSL uses it as the trigger to start
 * processing the incoming data.
 */
#define TEST_STRING "hello world\n"
/* Sized with sizeof(), so it has room for TEST_STRING plus its terminating
 * NUL. Only sizeof(TEST_STRING) - 1 bytes are ever sent or received, so the
 * last byte stays 0 and the buffer can be logged with %s.
 */
static uint8_t test_buf[sizeof(TEST_STRING)];
66 
wait_for_event(void)67 static int wait_for_event(void)
68 {
69 	int ret;
70 
71 	/* Wait for event on any socket used. Once event occurs,
72 	 * we'll check them all.
73 	 */
74 	ret = poll(fds, ARRAY_SIZE(fds), -1);
75 	if (ret < 0) {
76 		LOG_ERR("Error in poll (%d)", errno);
77 		return ret;
78 	}
79 
80 	return 0;
81 }
82 
create_socket(void)83 static int create_socket(void)
84 {
85 	int ret = 0;
86 	struct sockaddr_in addr;
87 
88 	addr.sin_family = AF_INET;
89 	addr.sin_port = htons(CONFIG_SERVER_PORT);
90 	inet_pton(AF_INET, "127.0.0.1", &addr.sin_addr);
91 
92 #if defined(CONFIG_MBEDTLS_TLS_VERSION_1_3)
93 	socket_fd = socket(addr.sin_family, SOCK_STREAM, IPPROTO_TLS_1_3);
94 #else
95 	socket_fd = socket(addr.sin_family, SOCK_STREAM, IPPROTO_TLS_1_2);
96 #endif
97 	if (socket_fd < 0) {
98 		LOG_ERR("Failed to create TLS socket (%d)", errno);
99 		return -errno;
100 	}
101 
102 	sec_tag_t sec_tag_list[] = {
103 #if defined(USE_CERTIFICATE)
104 		CA_CERTIFICATE_TAG,
105 #endif
106 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
107 		PSK_TAG,
108 #endif
109 	};
110 
111 	ret = setsockopt(socket_fd, SOL_TLS, TLS_SEC_TAG_LIST,
112 			sec_tag_list, sizeof(sec_tag_list));
113 	if (ret < 0) {
114 		LOG_ERR("Failed to set TLS_SEC_TAG_LIST option (%d)", errno);
115 		return -errno;
116 	}
117 
118 	/* HOSTNAME is only required for key exchanges that use a certificate. */
119 #if defined(USE_CERTIFICATE)
120 	ret = setsockopt(socket_fd, SOL_TLS, TLS_HOSTNAME,
121 			 "localhost", sizeof("localhost"));
122 	if (ret < 0) {
123 		LOG_ERR("Failed to set TLS_HOSTNAME option (%d)", errno);
124 		return -errno;
125 	}
126 #endif
127 
128 	ret = connect(socket_fd, (struct sockaddr *) &addr, sizeof(addr));
129 	if (ret < 0) {
130 		LOG_ERR("Cannot connect to TCP remote (%d)", errno);
131 		return -errno;
132 	}
133 
134 	/* Prepare file descriptor for polling */
135 	fds[0].fd = socket_fd;
136 	fds[0].events = POLLIN;
137 
138 	return ret;
139 }
140 
close_socket(void)141 void close_socket(void)
142 {
143 	if (socket_fd != INVALID_SOCKET) {
144 		close(socket_fd);
145 	}
146 }
147 
setup_credentials(void)148 static int setup_credentials(void)
149 {
150 	int err;
151 
152 #if defined(USE_CERTIFICATE)
153 	err = tls_credential_add(CA_CERTIFICATE_TAG,
154 				TLS_CREDENTIAL_CA_CERTIFICATE,
155 				certificate,
156 				sizeof(certificate));
157 	if (err < 0) {
158 		LOG_ERR("Failed to register certificate: %d", err);
159 		return err;
160 	}
161 #endif
162 
163 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
164 	err = tls_credential_add(PSK_TAG,
165 				TLS_CREDENTIAL_PSK,
166 				psk,
167 				sizeof(psk));
168 	if (err < 0) {
169 		LOG_ERR("Failed to register PSK: %d", err);
170 	}
171 	err = tls_credential_add(PSK_TAG,
172 				TLS_CREDENTIAL_PSK_ID,
173 				psk_id,
174 				sizeof(psk_id) - 1);
175 	if (err < 0) {
176 		LOG_ERR("Failed to register PSK ID: %d", err);
177 	}
178 #endif
179 
180 	return 0;
181 }
182 
main(void)183 int main(void)
184 {
185 	int ret;
186 	size_t data_len;
187 
188 	LOG_INF(APP_BANNER);
189 
190 	setup_credentials();
191 
192 	ret = create_socket();
193 	if (ret < 0) {
194 		LOG_ERR("Socket creation failed (%d)", ret);
195 		goto exit;
196 	}
197 
198 	memcpy(test_buf, TEST_STRING, sizeof(TEST_STRING));
199 	/* The -1 here is because sizeof() accounts for "\0" but that's not
200 	 * needed for socket functions send/recv.
201 	 */
202 	data_len = sizeof(TEST_STRING) - 1;
203 
204 	/* OpenSSL s_server has only the "-rev" option as echo-like behavior
205 	 * which echoes back the data that we send it in reversed order. So
206 	 * in the following we send the test buffer twice (in the 1st iteration
207 	 * it will contain the original TEST_STRING, whereas in the 2nd one
208 	 * it will contain TEST_STRING reversed) so that in the end we can
209 	 * just memcmp() it against the original TEST_STRING.
210 	 */
211 	for (int i = 0; i < 2; i++) {
212 		LOG_DBG("Send: %s", test_buf);
213 		ret = send(socket_fd, test_buf, data_len, 0);
214 		if (ret < 0) {
215 			LOG_ERR("Error sending test string (%d)", errno);
216 			goto exit;
217 		}
218 
219 		memset(test_buf, 0, sizeof(test_buf));
220 
221 		wait_for_event();
222 
223 		ret = recv(socket_fd, test_buf, data_len, MSG_WAITALL);
224 		if (ret == 0) {
225 			LOG_ERR("Server terminated unexpectedly");
226 			ret = -EIO;
227 			goto exit;
228 		} else if (ret < 0) {
229 			LOG_ERR("Error receiving data (%d)", errno);
230 			goto exit;
231 		}
232 		if (ret != data_len) {
233 			LOG_ERR("Sent %d bytes, but received %d", data_len, ret);
234 			ret = -EINVAL;
235 			goto exit;
236 		}
237 		LOG_DBG("Received: %s", test_buf);
238 	}
239 
240 	ret = memcmp(TEST_STRING, test_buf, data_len);
241 	if (ret != 0) {
242 		LOG_ERR("Received data does not match with TEST_STRING");
243 		LOG_HEXDUMP_ERR(test_buf, data_len, "Received:");
244 		LOG_HEXDUMP_ERR(TEST_STRING, data_len, "Expected:");
245 		ret = -EINVAL;
246 	}
247 
248 exit:
249 	LOG_INF("Test %s", (ret < 0) ? "FAILED" : "PASSED");
250 
251 	close_socket();
252 
253 	return 0;
254 }
255