1 /*
2 * Copyright 2022 Young Mei
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include <unistd.h>
8
9 #include <zephyr/ztest.h>
10
11 #include <thrift/protocol/TBinaryProtocol.h>
12 #include <thrift/protocol/TCompactProtocol.h>
13 #include <thrift/server/TSimpleServer.h>
14 #include <thrift/transport/TBufferTransports.h>
15 #include <thrift/transport/TFDTransport.h>
16 #include <thrift/transport/TSSLServerSocket.h>
17 #include <thrift/transport/TSSLSocket.h>
18 #include <thrift/transport/TServerSocket.h>
19
20 #include "context.hpp"
21 #include "server.hpp"
22 #include "thrift/server/TFDServer.h"
23
24 using namespace apache::thrift;
25 using namespace apache::thrift::protocol;
26 using namespace apache::thrift::transport;
27
/* Shared suite state: server/client objects, the socket pair (fds) and the
 * server pthread handle. NOTE(review): the ctx type is presumably declared in
 * context.hpp.
 */
ctx context;
/* Stack for the Thrift server pthread; passed to pthread_attr_setstack() in
 * thrift_test_before() on targets other than CONFIG_ARCH_POSIX.
 */
static K_THREAD_STACK_DEFINE(ThriftTest_server_stack, CONFIG_THRIFTTEST_SERVER_STACK_SIZE);
/* PEM certificate embedded from qemu_cert.pem.inc. It is NUL-terminated so it
 * can be handed to TSSLSocketFactory as a C string; it is used both as the
 * own certificate and as the trusted certificate in thrift_test_setup().
 */
static const char cert_pem[] = {
#include "qemu_cert.pem.inc"
	'\0'
};
/* PEM private key embedded from qemu_key.pem.inc, NUL-terminated like cert_pem. */
static const char key_pem[] = {
#include "qemu_key.pem.inc"
	'\0'
};
38
server_func(void * arg)39 static void *server_func(void *arg)
40 {
41 (void)arg;
42
43 context.server->serve();
44
45 return nullptr;
46 }
47
thrift_test_setup(void)48 static void *thrift_test_setup(void)
49 {
50 if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) {
51 TSSLSocketFactory socketFactory;
52 socketFactory.loadCertificateFromBuffer((const char *)&cert_pem[0]);
53 socketFactory.loadPrivateKeyFromBuffer((const char *)&key_pem[0]);
54 socketFactory.loadTrustedCertificatesFromBuffer((const char *)&cert_pem[0]);
55 }
56
57 return NULL;
58 }
59
setup_client()60 static std::unique_ptr<ThriftTestClient> setup_client()
61 {
62 std::shared_ptr<TTransport> transport;
63 std::shared_ptr<TProtocol> protocol;
64 std::shared_ptr<TTransport> trans(new TFDTransport(context.fds[ctx::CLIENT]));
65
66 if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) {
67 const int port = 4242;
68 std::shared_ptr<TSSLSocketFactory> socketFactory =
69 std::make_shared<TSSLSocketFactory>();
70 socketFactory->authenticate(true);
71 trans = socketFactory->createSocket(CONFIG_NET_CONFIG_MY_IPV4_ADDR, port);
72 } else {
73 trans = std::make_shared<TFDTransport>(context.fds[ctx::CLIENT]);
74 }
75
76 transport = std::make_shared<TBufferedTransport>(trans);
77
78 if (IS_ENABLED(CONFIG_THRIFT_COMPACT_PROTOCOL)) {
79 protocol = std::make_shared<TCompactProtocol>(transport);
80 } else {
81 protocol = std::make_shared<TBinaryProtocol>(transport);
82 }
83 transport->open();
84 return std::unique_ptr<ThriftTestClient>(new ThriftTestClient(protocol));
85 }
86
setup_server()87 static std::unique_ptr<TServer> setup_server()
88 {
89 std::shared_ptr<TestHandler> handler(new TestHandler());
90 std::shared_ptr<TProcessor> processor(new ThriftTestProcessor(handler));
91 std::shared_ptr<TServerTransport> serverTransport;
92 std::shared_ptr<TProtocolFactory> protocolFactory;
93 std::shared_ptr<TTransportFactory> transportFactory;
94
95 if (IS_ENABLED(CONFIG_THRIFT_SSL_SOCKET)) {
96 const int port = 4242;
97 std::shared_ptr<TSSLSocketFactory> socketFactory(new TSSLSocketFactory());
98 socketFactory->server(true);
99 serverTransport =
100 std::make_shared<TSSLServerSocket>("0.0.0.0", port, socketFactory);
101 } else {
102 serverTransport = std::make_shared<TFDServer>(context.fds[ctx::SERVER]);
103 }
104
105 transportFactory = std::make_shared<TBufferedTransportFactory>();
106
107 if (IS_ENABLED(CONFIG_THRIFT_COMPACT_PROTOCOL)) {
108 protocolFactory = std::make_shared<TCompactProtocolFactory>();
109 } else {
110 protocolFactory = std::make_shared<TBinaryProtocolFactory>();
111 }
112 TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory);
113 return std::unique_ptr<TServer>(
114 new TSimpleServer(processor, serverTransport, transportFactory, protocolFactory));
115 }
116
thrift_test_before(void * data)117 static void thrift_test_before(void *data)
118 {
119 ARG_UNUSED(data);
120 int rv;
121
122 pthread_attr_t attr;
123 pthread_attr_t *attrp = &attr;
124
125 if (IS_ENABLED(CONFIG_ARCH_POSIX)) {
126 attrp = NULL;
127 } else {
128 rv = pthread_attr_init(attrp);
129 zassert_equal(0, rv, "pthread_attr_init failed: %d", rv);
130 rv = pthread_attr_setstack(attrp, ThriftTest_server_stack,
131 CONFIG_THRIFTTEST_SERVER_STACK_SIZE);
132 zassert_equal(0, rv, "pthread_attr_setstack failed: %d", rv);
133 }
134
135 // create the communication channel
136 rv = socketpair(AF_UNIX, SOCK_STREAM, 0, &context.fds.front());
137 zassert_equal(0, rv, "socketpair failed: %d\n", rv);
138
139 // set up server
140 context.server = setup_server();
141
142 // start the server
143 rv = pthread_create(&context.server_thread, attrp, server_func, nullptr);
144 zassert_equal(0, rv, "pthread_create failed: %d", rv);
145
146 /* Give the server thread a chance to start and prepare the socket */
147 k_msleep(50);
148
149 // set up client
150 context.client = setup_client();
151 }
152
thrift_test_after(void * data)153 static void thrift_test_after(void *data)
154 {
155 ARG_UNUSED(data);
156
157 context.server->stop();
158
159 pthread_join(context.server_thread, NULL);
160
161 for (auto &fd : context.fds) {
162 close(fd);
163 fd = -1;
164 }
165
166 context.client.reset();
167 context.server.reset();
168
169 k_msleep(CONFIG_NET_TCP_TIME_WAIT_DELAY);
170 }
171
/* Register the "thrift" suite:
 *  - no predicate;
 *  - thrift_test_setup as suite setup;
 *  - thrift_test_before / thrift_test_after around each test;
 *  - no suite teardown.
 */
ZTEST_SUITE(thrift, NULL, thrift_test_setup, thrift_test_before, thrift_test_after, NULL);
173