1 /*
2  * Copyright 2022 Young Mei
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <unistd.h>
8 
9 #include <zephyr/ztest.h>
10 
11 #include <thrift/protocol/TBinaryProtocol.h>
12 #include <thrift/protocol/TCompactProtocol.h>
13 #include <thrift/server/TSimpleServer.h>
14 #include <thrift/transport/TBufferTransports.h>
15 #include <thrift/transport/TFDTransport.h>
16 #include <thrift/transport/TSSLServerSocket.h>
17 #include <thrift/transport/TSSLSocket.h>
18 #include <thrift/transport/TServerSocket.h>
19 
20 #include "context.hpp"
21 #include "server.hpp"
22 #include "thrift/server/TFDServer.h"
23 
24 using namespace apache::thrift;
25 using namespace apache::thrift::protocol;
26 using namespace apache::thrift::transport;
27 
/* Shared test state (socketpair fds, server, client, server thread) used by the hooks below. */
ctx context;
/* Static stack handed to the server pthread via pthread_attr_setstack() on non-ARCH_POSIX targets. */
static K_THREAD_STACK_DEFINE(ThriftTest_server_stack, CONFIG_THRIFTTEST_SERVER_STACK_SIZE);
/*
 * PEM test certificate and private key, pulled in from the included .inc
 * files and explicitly NUL-terminated so they can be passed as C strings to
 * the TSSLSocketFactory load*FromBuffer() calls in thrift_test_setup().
 * NOTE(review): presumes each .inc expands to a comma-terminated list of
 * character values — confirm against the build step that produces them.
 */
static const char cert_pem[] = {
#include "qemu_cert.pem.inc"
	'\0'
};
static const char key_pem[] = {
#include "qemu_key.pem.inc"
	'\0'
};
38 
server_func(void * arg)39 static void *server_func(void *arg)
40 {
41 	(void)arg;
42 
43 	context.server->serve();
44 
45 	return nullptr;
46 }
47 
thrift_test_setup(void)48 static void *thrift_test_setup(void)
49 {
50 	if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) {
51 		TSSLSocketFactory socketFactory;
52 		socketFactory.loadCertificateFromBuffer((const char *)&cert_pem[0]);
53 		socketFactory.loadPrivateKeyFromBuffer((const char *)&key_pem[0]);
54 		socketFactory.loadTrustedCertificatesFromBuffer((const char *)&cert_pem[0]);
55 	}
56 
57 	return NULL;
58 }
59 
setup_client()60 static std::unique_ptr<ThriftTestClient> setup_client()
61 {
62 	std::shared_ptr<TTransport> transport;
63 	std::shared_ptr<TProtocol> protocol;
64 	std::shared_ptr<TTransport> trans(new TFDTransport(context.fds[ctx::CLIENT]));
65 
66 	if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) {
67 		const int port = 4242;
68 		std::shared_ptr<TSSLSocketFactory> socketFactory =
69 			std::make_shared<TSSLSocketFactory>();
70 		socketFactory->authenticate(true);
71 		trans = socketFactory->createSocket(CONFIG_NET_CONFIG_MY_IPV4_ADDR, port);
72 	} else {
73 		trans = std::make_shared<TFDTransport>(context.fds[ctx::CLIENT]);
74 	}
75 
76 	transport = std::make_shared<TBufferedTransport>(trans);
77 
78 	if (IS_ENABLED(CONFIG_THRIFT_COMPACT_PROTOCOL)) {
79 		protocol = std::make_shared<TCompactProtocol>(transport);
80 	} else {
81 		protocol = std::make_shared<TBinaryProtocol>(transport);
82 	}
83 	transport->open();
84 	return std::unique_ptr<ThriftTestClient>(new ThriftTestClient(protocol));
85 }
86 
setup_server()87 static std::unique_ptr<TServer> setup_server()
88 {
89 	std::shared_ptr<TestHandler> handler(new TestHandler());
90 	std::shared_ptr<TProcessor> processor(new ThriftTestProcessor(handler));
91 	std::shared_ptr<TServerTransport> serverTransport;
92 	std::shared_ptr<TProtocolFactory> protocolFactory;
93 	std::shared_ptr<TTransportFactory> transportFactory;
94 
95 	if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) {
96 		const int port = 4242;
97 		std::shared_ptr<TSSLSocketFactory> socketFactory(new TSSLSocketFactory());
98 		socketFactory->server(true);
99 		serverTransport =
100 			std::make_shared<TSSLServerSocket>("0.0.0.0", port, socketFactory);
101 	} else {
102 		serverTransport = std::make_shared<TFDServer>(context.fds[ctx::SERVER]);
103 	}
104 
105 	transportFactory = std::make_shared<TBufferedTransportFactory>();
106 
107 	if (IS_ENABLED(CONFIG_THRIFT_COMPACT_PROTOCOL)) {
108 		protocolFactory = std::make_shared<TCompactProtocolFactory>();
109 	} else {
110 		protocolFactory = std::make_shared<TBinaryProtocolFactory>();
111 	}
112 	TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory);
113 	return std::unique_ptr<TServer>(
114 		new TSimpleServer(processor, serverTransport, transportFactory, protocolFactory));
115 }
116 
thrift_test_before(void * data)117 static void thrift_test_before(void *data)
118 {
119 	ARG_UNUSED(data);
120 	int rv;
121 
122 	pthread_attr_t attr;
123 	pthread_attr_t *attrp = &attr;
124 
125 	if (IS_ENABLED(CONFIG_ARCH_POSIX)) {
126 		attrp = NULL;
127 	} else {
128 		rv = pthread_attr_init(attrp);
129 		zassert_equal(0, rv, "pthread_attr_init failed: %d", rv);
130 		rv = pthread_attr_setstack(attrp, ThriftTest_server_stack,
131 					   CONFIG_THRIFTTEST_SERVER_STACK_SIZE);
132 		zassert_equal(0, rv, "pthread_attr_setstack failed: %d", rv);
133 	}
134 
135 	// create the communication channel
136 	rv = socketpair(AF_UNIX, SOCK_STREAM, 0, &context.fds.front());
137 	zassert_equal(0, rv, "socketpair failed: %d\n", rv);
138 
139 	// set up server
140 	context.server = setup_server();
141 
142 	// start the server
143 	rv = pthread_create(&context.server_thread, attrp, server_func, nullptr);
144 	zassert_equal(0, rv, "pthread_create failed: %d", rv);
145 
146 	/* Give the server thread a chance to start and prepare the socket */
147 	k_msleep(50);
148 
149 	// set up client
150 	context.client = setup_client();
151 }
152 
thrift_test_after(void * data)153 static void thrift_test_after(void *data)
154 {
155 	ARG_UNUSED(data);
156 
157 	context.server->stop();
158 
159 	pthread_join(context.server_thread, NULL);
160 
161 	for (auto &fd : context.fds) {
162 		close(fd);
163 		fd = -1;
164 	}
165 
166 	context.client.reset();
167 	context.server.reset();
168 
169 	k_msleep(CONFIG_NET_TCP_TIME_WAIT_DELAY);
170 }
171 
/*
 * Register the "thrift" suite: no predicate, thrift_test_setup() as the
 * suite setup hook, thrift_test_before()/thrift_test_after() around each
 * test, and no suite teardown.
 */
ZTEST_SUITE(thrift, NULL, thrift_test_setup, thrift_test_before, thrift_test_after, NULL);
173