1 /*
2 * Copyright (c) 2024 Nordic Semiconductor ASA
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7 #include <zephyr/init.h>
8
9 #include <zephyr/logging/log.h>
10 #include <zephyr/net/socket.h>
11 #include <zephyr/net/socket_service.h>
12 #include <zephyr/net/http/service.h>
13 #include <zephyr/net/http/server.h>
14 #include <zephyr/net/websocket.h>
15 #include <zephyr/shell/shell.h>
16 #include <zephyr/shell/shell_websocket.h>
17 #include <zephyr/logging/log_backend_ws.h>
18
LOG_MODULE_REGISTER(shell_websocket, CONFIG_SHELL_WEBSOCKET_INIT_LOG_LEVEL);

/* Capacity of the output line buffer (line_out); a full buffer is sent
 * immediately by sh_write().
 */
#define WEBSOCKET_LINE_SIZE CONFIG_SHELL_WEBSOCKET_LINE_BUF_SIZE
/* Delay, in milliseconds, before buffered output that does not end in a
 * newline is flushed by the send_work handler.
 */
#define WEBSOCKET_TIMEOUT CONFIG_SHELL_WEBSOCKET_SEND_TIMEOUT

/* NOTE(review): the two macros below are not referenced anywhere in this
 * file; confirm they are unused and remove them.
 */
#define WEBSOCKET_MIN_COMMAND_LEN 2
#define WEBSOCKET_WILL_DO_COMMAND_LEN 3

/* Forward declaration: the socket service definition below needs the
 * callback before its body appears.
 */
static void ws_server_cb(struct net_socket_service_event *evt);

/* Socket service that watches the client sockets registered in
 * shell_ws_init() and delivers their poll events to ws_server_cb().
 * SHELL_WEBSOCKET_SERVICE_COUNT is defined outside this file (presumably the
 * number of websocket shell instances - confirm).
 */
NET_SOCKET_SERVICE_SYNC_DEFINE_STATIC(websocket_server, ws_server_cb,
				      SHELL_WEBSOCKET_SERVICE_COUNT);
31
ws_end_client_connection(struct shell_websocket * ws)32 static void ws_end_client_connection(struct shell_websocket *ws)
33 {
34 int ret;
35
36 LOG_DBG("Closing connection to #%d", ws->fds[0].fd);
37
38 (void)log_backend_ws_unregister(ws->fds[0].fd);
39
40 (void)websocket_unregister(ws->fds[0].fd);
41
42 ws->fds[0].fd = -1;
43 ws->output_lock = false;
44
45 k_work_cancel_delayable_sync(&ws->send_work, &ws->work_sync);
46
47 ret = net_socket_service_register(&websocket_server, ws->fds,
48 ARRAY_SIZE(ws->fds), NULL);
49 if (ret < 0) {
50 LOG_ERR("Failed to re-register socket service (%d)", ret);
51 }
52 }
53
ws_send(struct shell_websocket * ws,bool block)54 static int ws_send(struct shell_websocket *ws, bool block)
55 {
56 int ret;
57 uint8_t *msg = ws->line_out.buf;
58 uint16_t len = ws->line_out.len;
59
60 if (ws->line_out.len == 0) {
61 return 0;
62 }
63
64 if (ws->fds[0].fd < 0) {
65 return -ENOTCONN;
66 }
67
68 while (len > 0) {
69 ret = zsock_send(ws->fds[0].fd, msg, len,
70 block ? 0 : ZSOCK_MSG_DONTWAIT);
71 if (!block && (ret < 0) && (errno == EAGAIN)) {
72 /* Not all data was sent - move the remaining data and
73 * update length.
74 */
75 memmove(ws->line_out.buf, msg, len);
76 ws->line_out.len = len;
77 return -EAGAIN;
78 }
79
80 if (ret < 0) {
81 ret = -errno;
82 LOG_ERR("Failed to send %d, shutting down", -ret);
83 ws_end_client_connection(ws);
84 return ret;
85 }
86
87 msg += ret;
88 len -= ret;
89 }
90
91 /* We reinitialize the line buffer */
92 ws->line_out.len = 0;
93
94 return 0;
95 }
96
ws_send_prematurely(struct k_work * work)97 static void ws_send_prematurely(struct k_work *work)
98 {
99 struct k_work_delayable *dwork = k_work_delayable_from_work(work);
100 struct shell_websocket *ws = CONTAINER_OF(dwork,
101 struct shell_websocket,
102 send_work);
103 int ret;
104
105 /* Use non-blocking send to prevent system workqueue blocking. */
106 ret = ws_send(ws, false);
107 if (ret == -EAGAIN) {
108 /* Not all data was sent, reschedule the work. */
109 k_work_reschedule(&ws->send_work, K_MSEC(WEBSOCKET_TIMEOUT));
110 }
111 }
112
ws_recv(struct shell_websocket * ws,struct zsock_pollfd * pollfd)113 static void ws_recv(struct shell_websocket *ws, struct zsock_pollfd *pollfd)
114 {
115 size_t len, buf_left;
116 uint8_t *buf;
117 int ret;
118
119 k_mutex_lock(&ws->rx_lock, K_FOREVER);
120
121 buf_left = sizeof(ws->rx_buf) - ws->rx_len;
122 if (buf_left == 0) {
123 /* No space left to read TCP stream, try again later. */
124 k_mutex_unlock(&ws->rx_lock);
125 k_msleep(10);
126 return;
127 }
128
129 buf = ws->rx_buf + ws->rx_len;
130
131 ret = zsock_recv(pollfd->fd, buf, buf_left, 0);
132 if (ret < 0) {
133 LOG_DBG("Websocket client error %d", ret);
134 goto error;
135 } else if (ret == 0) {
136 LOG_DBG("Websocket client closed connection");
137 goto error;
138 }
139
140 len = ret;
141
142 if (len == 0) {
143 k_mutex_unlock(&ws->rx_lock);
144 return;
145 }
146
147 ws->rx_len += len;
148
149 k_mutex_unlock(&ws->rx_lock);
150
151 ws->shell_handler(SHELL_TRANSPORT_EVT_RX_RDY, ws->shell_context);
152
153 return;
154
155 error:
156 k_mutex_unlock(&ws->rx_lock);
157 ws_end_client_connection(ws);
158 }
159
ws_server_cb(struct net_socket_service_event * evt)160 static void ws_server_cb(struct net_socket_service_event *evt)
161 {
162 socklen_t optlen = sizeof(int);
163 struct shell_websocket *ws;
164 int sock_error;
165
166 ws = (struct shell_websocket *)evt->user_data;
167
168 if ((evt->event.revents & ZSOCK_POLLERR) ||
169 (evt->event.revents & ZSOCK_POLLNVAL)) {
170 (void)zsock_getsockopt(evt->event.fd, SOL_SOCKET,
171 SO_ERROR, &sock_error, &optlen);
172 LOG_ERR("Websocket socket %d error (%d)", evt->event.fd, sock_error);
173
174 if (evt->event.fd == ws->fds[0].fd) {
175 return ws_end_client_connection(ws);
176 }
177
178 return;
179 }
180
181 if (!(evt->event.revents & ZSOCK_POLLIN)) {
182 return;
183 }
184
185 if (evt->event.fd == ws->fds[0].fd) {
186 return ws_recv(ws, &ws->fds[0]);
187 }
188 }
189
shell_ws_init(struct shell_websocket * ctx,int ws_socket)190 static int shell_ws_init(struct shell_websocket *ctx, int ws_socket)
191 {
192 int ret;
193
194 if (ws_socket < 0) {
195 LOG_ERR("Invalid socket %d", ws_socket);
196 return -EBADF;
197 }
198
199 if (ctx->fds[0].fd >= 0) {
200 /* There is already a websocket connection to this shell,
201 * kick the previous connection out.
202 */
203 ws_end_client_connection(ctx);
204 }
205
206 ctx->fds[0].fd = ws_socket;
207 ctx->fds[0].events = ZSOCK_POLLIN;
208
209 ret = net_socket_service_register(&websocket_server, ctx->fds,
210 ARRAY_SIZE(ctx->fds), ctx);
211 if (ret < 0) {
212 LOG_ERR("Failed to register socket service, %d", ret);
213 goto error;
214 }
215
216 log_backend_ws_register(ws_socket);
217
218 return 0;
219
220 error:
221 if (ctx->fds[0].fd >= 0) {
222 (void)zsock_close(ctx->fds[0].fd);
223 ctx->fds[0].fd = -1;
224 }
225
226 return ret;
227 }
228
229 /* Shell API */
230
/*
 * Shell transport init hook.
 *
 * Resets the websocket context stored in transport->ctx to all zeroes and
 * marks every poll slot as unused (fd = -1). It then records the shell event
 * handler and its context, and initialises the delayed flush work and the
 * RX mutex. @p config is not used.
 *
 * @return Always 0.
 */
static int init(const struct shell_transport *transport,
		const void *config,
		shell_transport_handler_t evt_handler,
		void *context)
{
	struct shell_websocket *ws = transport->ctx;

	(void)memset(ws, 0, sizeof(*ws));

	for (size_t slot = 0; slot < ARRAY_SIZE(ws->fds); slot++) {
		ws->fds[slot].fd = -1;
	}

	ws->shell_handler = evt_handler;
	ws->shell_context = context;

	k_work_init_delayable(&ws->send_work, ws_send_prematurely);
	k_mutex_init(&ws->rx_lock);

	return 0;
}
253
/* Shell transport uninit hook: no action is taken; always returns 0. */
static int uninit(const struct shell_transport *transport)
{
	(void)transport;
	return 0;
}
260
/* Shell transport enable hook: both arguments are ignored; always returns 0. */
static int enable(const struct shell_transport *transport, bool blocking)
{
	(void)transport;
	(void)blocking;
	return 0;
}
268
sh_write(const struct shell_transport * transport,const void * data,size_t length,size_t * cnt)269 static int sh_write(const struct shell_transport *transport,
270 const void *data, size_t length, size_t *cnt)
271 {
272 struct shell_websocket_line_buf *lb;
273 struct shell_websocket *ws;
274 uint32_t timeout;
275 bool was_running;
276 size_t copy_len;
277 int ret;
278
279 ws = (struct shell_websocket *)transport->ctx;
280
281 if (ws->fds[0].fd < 0 || ws->output_lock) {
282 *cnt = length;
283 return 0;
284 }
285
286 *cnt = 0;
287 lb = &ws->line_out;
288
289 /* Stop the transmission timer, so it does not interrupt the operation.
290 */
291 timeout = k_ticks_to_ms_ceil32(k_work_delayable_remaining_get(&ws->send_work));
292 was_running = k_work_cancel_delayable_sync(&ws->send_work, &ws->work_sync);
293
294 do {
295 if (lb->len + length - *cnt > WEBSOCKET_LINE_SIZE) {
296 copy_len = WEBSOCKET_LINE_SIZE - lb->len;
297 } else {
298 copy_len = length - *cnt;
299 }
300
301 memcpy(lb->buf + lb->len, (uint8_t *)data + *cnt, copy_len);
302 lb->len += copy_len;
303
304 /* Send the data immediately if the buffer is full or line feed
305 * is recognized.
306 */
307 if (lb->buf[lb->len - 1] == '\n' || lb->len == WEBSOCKET_LINE_SIZE) {
308 ret = ws_send(ws, true);
309 if (ret != 0) {
310 *cnt = length;
311 return ret;
312 }
313 }
314
315 *cnt += copy_len;
316 } while (*cnt < length);
317
318 if (lb->len > 0) {
319 /* Check if the timer was already running, initialize otherwise.
320 */
321 timeout = was_running ? timeout : WEBSOCKET_TIMEOUT;
322
323 k_work_reschedule(&ws->send_work, K_MSEC(timeout));
324 }
325
326 ws->shell_handler(SHELL_TRANSPORT_EVT_TX_RDY, ws->shell_context);
327
328 return 0;
329 }
330
sh_read(const struct shell_transport * transport,void * data,size_t length,size_t * cnt)331 static int sh_read(const struct shell_transport *transport,
332 void *data, size_t length, size_t *cnt)
333 {
334 struct shell_websocket *ws;
335 size_t read_len;
336
337 ws = (struct shell_websocket *)transport->ctx;
338
339 if (ws->fds[0].fd < 0) {
340 goto no_data;
341 }
342
343 k_mutex_lock(&ws->rx_lock, K_FOREVER);
344
345 if (ws->rx_len == 0) {
346 k_mutex_unlock(&ws->rx_lock);
347 goto no_data;
348 }
349
350 read_len = ws->rx_len;
351 if (read_len > length) {
352 read_len = length;
353 }
354
355 memcpy(data, ws->rx_buf, read_len);
356 *cnt = read_len;
357
358 ws->rx_len -= read_len;
359 if (ws->rx_len > 0) {
360 memmove(ws->rx_buf, ws->rx_buf + read_len, ws->rx_len);
361 }
362
363 k_mutex_unlock(&ws->rx_lock);
364
365 return 0;
366
367 no_data:
368 *cnt = 0;
369 return 0;
370 }
371
372 const struct shell_transport_api shell_websocket_transport_api = {
373 .init = init,
374 .uninit = uninit,
375 .enable = enable,
376 .write = sh_write,
377 .read = sh_read
378 };
379
/*
 * Attach a websocket client socket to a websocket shell instance.
 *
 * @param ws_socket Client socket descriptor.
 * @param user_data The struct shell_websocket of the target shell.
 *
 * @return Result of shell_ws_init(): 0 on success, negative error otherwise.
 */
int shell_websocket_setup(int ws_socket, void *user_data)
{
	return shell_ws_init((struct shell_websocket *)user_data, ws_socket);
}
386
shell_websocket_enable(const struct shell * sh)387 int shell_websocket_enable(const struct shell *sh)
388 {
389 bool log_backend = CONFIG_SHELL_WEBSOCKET_INIT_LOG_LEVEL > 0;
390 uint32_t level = (CONFIG_SHELL_WEBSOCKET_INIT_LOG_LEVEL > LOG_LEVEL_DBG) ?
391 CONFIG_LOG_MAX_LEVEL : CONFIG_SHELL_WEBSOCKET_INIT_LOG_LEVEL;
392 static const struct shell_backend_config_flags cfg_flags =
393 SHELL_DEFAULT_BACKEND_CONFIG_FLAGS;
394 int ret;
395
396 ret = shell_init(sh, NULL, cfg_flags, log_backend, level);
397 if (ret < 0) {
398 LOG_DBG("Cannot init websocket shell %p", sh);
399 }
400
401 return ret;
402 }
403