1 /*
2  * Copyright 2022 Meta
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
#include <algorithm>
#include <array>
#include <exception>
#include <memory>
#include <system_error>

#include <errno.h>
#include <zephyr/posix/poll.h>
#include <zephyr/posix/sys/eventfd.h>
#include <zephyr/posix/unistd.h>

#include <thrift/transport/TFDTransport.h>

#include <zephyr/kernel.h>
#include <zephyr/logging/log.h>

#include "thrift/server/TFDServer.h"
21 
/* Register the "TFDServer" log module for this file at LOG_LEVEL_INF; the
 * LOG_DBG() calls below are only emitted if this level is raised.
 */
LOG_MODULE_REGISTER(TFDServer, LOG_LEVEL_INF);

using namespace std;
25 
26 namespace apache
27 {
28 namespace thrift
29 {
30 namespace transport
31 {
32 
33 class xport : public TVirtualTransport<xport>
34 {
35       public:
xport(int fd)36 	xport(int fd) : xport(fd, eventfd(0, EFD_SEMAPHORE))
37 	{
38 	}
xport(int fd,int efd)39 	xport(int fd, int efd) : fd(fd), efd(efd)
40 	{
41 		__ASSERT(fd >= 0, "invalid fd %d", fd);
42 		__ASSERT(efd >= 0, "invalid efd %d", efd);
43 
44 		LOG_DBG("created xport with fd %d and efd %d", fd, efd);
45 	}
46 
~xport()47 	~xport()
48 	{
49 		close();
50 	}
51 
read_virt(uint8_t * buf,uint32_t len)52 	virtual uint32_t read_virt(uint8_t *buf, uint32_t len) override
53 	{
54 		int r;
55 		array<pollfd, 2> pollfds = {
56 			(pollfd){
57 				.fd = fd,
58 				.events = POLLIN,
59 				.revents = 0,
60 			},
61 			(pollfd){
62 				.fd = efd,
63 				.events = POLLIN,
64 				.revents = 0,
65 			},
66 		};
67 
68 		if (!isOpen()) {
69 			return 0;
70 		}
71 
72 		r = poll(&pollfds.front(), pollfds.size(), -1);
73 		if (r == -1) {
74 			if (efd == -1 || fd == -1) {
75 				/* channel has been closed */
76 				return 0;
77 			}
78 
79 			LOG_ERR("failed to poll fds %d, %d: %d", fd, efd, errno);
80 			throw system_error(errno, system_category(), "poll");
81 		}
82 
83 		for (auto &pfd : pollfds) {
84 			if (pfd.revents & POLLNVAL) {
85 				LOG_DBG("fd %d is invalid", pfd.fd);
86 				return 0;
87 			}
88 		}
89 
90 		if (pollfds[0].revents & POLLIN) {
91 			r = ::read(fd, buf, len);
92 			if (r == -1) {
93 				LOG_ERR("failed to read %d bytes from fd %d: %d", len, fd, errno);
94 				system_error(errno, system_category(), "read");
95 			}
96 
97 			__ASSERT_NO_MSG(r > 0);
98 
99 			return uint32_t(r);
100 		}
101 
102 		__ASSERT_NO_MSG(pollfds[1].revents & POLLIN);
103 
104 		return 0;
105 	}
106 
write_virt(const uint8_t * buf,uint32_t len)107 	virtual void write_virt(const uint8_t *buf, uint32_t len) override
108 	{
109 
110 		if (!isOpen()) {
111 			throw TTransportException(TTransportException::END_OF_FILE);
112 		}
113 
114 		for (int r = 0; len > 0; buf += r, len -= r) {
115 			r = ::write(fd, buf, len);
116 			if (r == -1) {
117 				LOG_ERR("writing %u bytes to fd %d failed: %d", len, fd, errno);
118 				throw system_error(errno, system_category(), "write");
119 			}
120 
121 			__ASSERT_NO_MSG(r > 0);
122 		}
123 	}
124 
interrupt()125 	void interrupt()
126 	{
127 		if (!isOpen()) {
128 			return;
129 		}
130 
131 		constexpr uint64_t x = 0xb7e;
132 		int r = ::write(efd, &x, sizeof(x));
133 		if (r == -1) {
134 			LOG_ERR("writing %zu bytes to fd %d failed: %d", sizeof(x), efd, errno);
135 			throw system_error(errno, system_category(), "write");
136 		}
137 
138 		__ASSERT_NO_MSG(r > 0);
139 
140 		LOG_DBG("interrupted xport with fd %d and efd %d", fd, efd);
141 
142 		// there is no interrupt() method in the parent class, but the intent of
143 		// interrupt() is to prevent future communication on this transport. The
144 		// most reliable way we have of doing this is to close it :-)
145 		close();
146 	}
147 
close()148 	void close() override
149 	{
150 		if (isOpen()) {
151 			::close(efd);
152 			LOG_DBG("closed xport with fd %d and efd %d", fd, efd);
153 
154 			efd = -1;
155 			// we only have a copy of fd and do not own it
156 			fd = -1;
157 		}
158 	}
159 
isOpen() const160 	bool isOpen() const override
161 	{
162 		return fd >= 0 && efd >= 0;
163 	}
164 
165       protected:
166 	int fd;
167 	int efd;
168 };
169 
TFDServer(int fd)170 TFDServer::TFDServer(int fd) : fd(fd)
171 {
172 }
173 
~TFDServer()174 TFDServer::~TFDServer()
175 {
176 	interruptChildren();
177 	interrupt();
178 }
179 
isOpen() const180 bool TFDServer::isOpen() const
181 {
182 	return fd >= 0;
183 }
184 
acceptImpl()185 shared_ptr<TTransport> TFDServer::acceptImpl()
186 {
187 	if (!isOpen()) {
188 		throw TTransportException(TTransportException::INTERRUPTED);
189 	}
190 
191 	children.push_back(shared_ptr<TTransport>(new xport(fd)));
192 
193 	return children.back();
194 }
195 
getSocketFD()196 THRIFT_SOCKET TFDServer::getSocketFD()
197 {
198 	return fd;
199 }
200 
close()201 void TFDServer::close()
202 {
203 	// we only have a copy of fd and do not own it
204 	fd = -1;
205 }
206 
interrupt()207 void TFDServer::interrupt()
208 {
209 	close();
210 }
211 
interruptChildren()212 void TFDServer::interruptChildren()
213 {
214 	for (auto c : children) {
215 		auto child = reinterpret_cast<xport *>(c.get());
216 		child->interrupt();
217 	}
218 
219 	children.clear();
220 }
221 } // namespace transport
222 } // namespace thrift
223 } // namespace apache
224