1 /*
2 * Copyright 2022 Meta
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
#include <algorithm>
#include <array>
#include <exception>
#include <memory>
#include <system_error>

#include <errno.h>
#include <zephyr/posix/poll.h>
#include <zephyr/posix/sys/eventfd.h>
#include <zephyr/posix/unistd.h>

#include <thrift/transport/TFDTransport.h>

#include <zephyr/kernel.h>
#include <zephyr/logging/log.h>

#include "thrift/server/TFDServer.h"

LOG_MODULE_REGISTER(TFDServer, LOG_LEVEL_INF);

using namespace std;
25
26 namespace apache
27 {
28 namespace thrift
29 {
30 namespace transport
31 {
32
33 class xport : public TVirtualTransport<xport>
34 {
35 public:
xport(int fd)36 xport(int fd) : xport(fd, eventfd(0, EFD_SEMAPHORE))
37 {
38 }
xport(int fd,int efd)39 xport(int fd, int efd) : fd(fd), efd(efd)
40 {
41 __ASSERT(fd >= 0, "invalid fd %d", fd);
42 __ASSERT(efd >= 0, "invalid efd %d", efd);
43
44 LOG_DBG("created xport with fd %d and efd %d", fd, efd);
45 }
46
~xport()47 ~xport()
48 {
49 close();
50 }
51
read_virt(uint8_t * buf,uint32_t len)52 virtual uint32_t read_virt(uint8_t *buf, uint32_t len) override
53 {
54 int r;
55 array<pollfd, 2> pollfds = {
56 (pollfd){
57 .fd = fd,
58 .events = POLLIN,
59 .revents = 0,
60 },
61 (pollfd){
62 .fd = efd,
63 .events = POLLIN,
64 .revents = 0,
65 },
66 };
67
68 if (!isOpen()) {
69 return 0;
70 }
71
72 r = poll(&pollfds.front(), pollfds.size(), -1);
73 if (r == -1) {
74 if (efd == -1 || fd == -1) {
75 /* channel has been closed */
76 return 0;
77 }
78
79 LOG_ERR("failed to poll fds %d, %d: %d", fd, efd, errno);
80 throw system_error(errno, system_category(), "poll");
81 }
82
83 for (auto &pfd : pollfds) {
84 if (pfd.revents & POLLNVAL) {
85 LOG_DBG("fd %d is invalid", pfd.fd);
86 return 0;
87 }
88 }
89
90 if (pollfds[0].revents & POLLIN) {
91 r = ::read(fd, buf, len);
92 if (r == -1) {
93 LOG_ERR("failed to read %d bytes from fd %d: %d", len, fd, errno);
94 system_error(errno, system_category(), "read");
95 }
96
97 __ASSERT_NO_MSG(r > 0);
98
99 return uint32_t(r);
100 }
101
102 __ASSERT_NO_MSG(pollfds[1].revents & POLLIN);
103
104 return 0;
105 }
106
write_virt(const uint8_t * buf,uint32_t len)107 virtual void write_virt(const uint8_t *buf, uint32_t len) override
108 {
109
110 if (!isOpen()) {
111 throw TTransportException(TTransportException::END_OF_FILE);
112 }
113
114 for (int r = 0; len > 0; buf += r, len -= r) {
115 r = ::write(fd, buf, len);
116 if (r == -1) {
117 LOG_ERR("writing %u bytes to fd %d failed: %d", len, fd, errno);
118 throw system_error(errno, system_category(), "write");
119 }
120
121 __ASSERT_NO_MSG(r > 0);
122 }
123 }
124
interrupt()125 void interrupt()
126 {
127 if (!isOpen()) {
128 return;
129 }
130
131 constexpr uint64_t x = 0xb7e;
132 int r = ::write(efd, &x, sizeof(x));
133 if (r == -1) {
134 LOG_ERR("writing %zu bytes to fd %d failed: %d", sizeof(x), efd, errno);
135 throw system_error(errno, system_category(), "write");
136 }
137
138 __ASSERT_NO_MSG(r > 0);
139
140 LOG_DBG("interrupted xport with fd %d and efd %d", fd, efd);
141
142 // there is no interrupt() method in the parent class, but the intent of
143 // interrupt() is to prevent future communication on this transport. The
144 // most reliable way we have of doing this is to close it :-)
145 close();
146 }
147
close()148 void close() override
149 {
150 if (isOpen()) {
151 ::close(efd);
152 LOG_DBG("closed xport with fd %d and efd %d", fd, efd);
153
154 efd = -1;
155 // we only have a copy of fd and do not own it
156 fd = -1;
157 }
158 }
159
isOpen() const160 bool isOpen() const override
161 {
162 return fd >= 0 && efd >= 0;
163 }
164
165 protected:
166 int fd;
167 int efd;
168 };
169
TFDServer(int fd)170 TFDServer::TFDServer(int fd) : fd(fd)
171 {
172 }
173
~TFDServer()174 TFDServer::~TFDServer()
175 {
176 interruptChildren();
177 interrupt();
178 }
179
isOpen() const180 bool TFDServer::isOpen() const
181 {
182 return fd >= 0;
183 }
184
acceptImpl()185 shared_ptr<TTransport> TFDServer::acceptImpl()
186 {
187 if (!isOpen()) {
188 throw TTransportException(TTransportException::INTERRUPTED);
189 }
190
191 children.push_back(shared_ptr<TTransport>(new xport(fd)));
192
193 return children.back();
194 }
195
getSocketFD()196 THRIFT_SOCKET TFDServer::getSocketFD()
197 {
198 return fd;
199 }
200
close()201 void TFDServer::close()
202 {
203 // we only have a copy of fd and do not own it
204 fd = -1;
205 }
206
interrupt()207 void TFDServer::interrupt()
208 {
209 close();
210 }
211
interruptChildren()212 void TFDServer::interruptChildren()
213 {
214 for (auto c : children) {
215 auto child = reinterpret_cast<xport *>(c.get());
216 child->interrupt();
217 }
218
219 children.clear();
220 }
221 } // namespace transport
222 } // namespace thrift
223 } // namespace apache
224