/*
 * Copyright (c) 2024 Måns Ansgariusson <mansgariusson@gmail.com>
 *
 * SPDX-License-Identifier: Apache-2.0
 */
#include <stdint.h>
#include <zephyr/kernel.h>
#include <zephyr/ztest.h>
#include <zephyr/random/random.h>
#include <zephyr/logging/log.h>

/*
 * Suite registration: k_pipe_concurrency has no predicate and no
 * setup/before/after/teardown hooks.
 */
LOG_MODULE_REGISTER(k_pipe_concurrency, LOG_LEVEL_DBG);
ZTEST_SUITE(k_pipe_concurrency, NULL, NULL, NULL, NULL, NULL);

/* Byte count of every pipe buffer and transfer used by the tests. */
#define DUMMY_DATA_SIZE 16

/*
 * Timeout in milliseconds for the helper threads' blocking pipe calls; the
 * partial-transfer tests sleep for a fraction of it.
 */
static const int partial_wait_time = 2000;

/* Single helper thread (and its stack) that each test creates and joins. */
static struct k_thread thread;
static K_THREAD_STACK_DEFINE(stack, 1024 + CONFIG_TEST_EXTRA_STACK_SIZE);

/* Pipe object re-initialised by each test that uses a buffered pipe. */
static struct k_pipe pipe;

/*
 * Helper thread entry point: closes the pipe passed in arg1.
 * arg2 and arg3 are unused.
 */
static void thread_close(void *arg1, void *arg2, void *arg3)
{
	struct k_pipe *target = arg1;

	(void)arg2;
	(void)arg3;

	k_pipe_close(target);
}

/*
 * Helper thread entry point: resets the pipe passed in arg1.
 * arg2 and arg3 are unused.
 */
static void thread_reset(void *arg1, void *arg2, void *arg3)
{
	struct k_pipe *target = arg1;

	(void)arg2;
	(void)arg3;

	k_pipe_reset(target);
}

/*
 * Helper thread entry point: writes DUMMY_DATA_SIZE zero bytes into the pipe
 * passed in arg1. It waits at most partial_wait_time ms and asserts that the
 * whole buffer was accepted. arg2 and arg3 are unused.
 */
static void thread_write(void *arg1, void *arg2, void *arg3)
{
	struct k_pipe *target = arg1;
	uint8_t payload[DUMMY_DATA_SIZE] = {0};
	int written;

	(void)arg2;
	(void)arg3;

	written = k_pipe_write(target, payload, sizeof(payload), K_MSEC(partial_wait_time));
	zassert_true(written == sizeof(payload), "Failed to write to pipe");
}

/*
 * Helper thread entry point: reads DUMMY_DATA_SIZE bytes from the pipe passed
 * in arg1. It waits at most partial_wait_time ms and asserts that the full
 * amount was received. arg2 and arg3 are unused.
 */
static void thread_read(void *arg1, void *arg2, void *arg3)
{
	struct k_pipe *source = arg1;
	uint8_t sink[DUMMY_DATA_SIZE];
	int got;

	(void)arg2;
	(void)arg3;

	got = k_pipe_read(source, sink, sizeof(sink), K_MSEC(partial_wait_time));
	zassert_true(got == sizeof(sink), "Failed to read from pipe");
}

/*
 * A read on an empty pipe must return -EPIPE when another thread closes the
 * pipe, and the pipe must remain closed once the waiter has been released.
 */
ZTEST(k_pipe_concurrency, test_close_on_read)
{
	k_tid_t tid;
	uint8_t buffer[DUMMY_DATA_SIZE];
	uint8_t res;

	k_pipe_init(&pipe, buffer, sizeof(buffer));
	/* Delay the closer by 100 ms so the read below is expected to be blocked first. */
	tid = k_thread_create(&thread, stack, K_THREAD_STACK_SIZEOF(stack),
		thread_close, &pipe, NULL, NULL, K_PRIO_COOP(0), 0, K_MSEC(100));
	zassert_not_null(tid, "k_thread_create failed");
	zassert_true(k_pipe_read(&pipe, &res, sizeof(res), K_MSEC(1000)) == -EPIPE,
		"Read on closed pipe should return -EPIPE");
	k_thread_join(tid, K_FOREVER);
	zassert_true((pipe.flags & PIPE_FLAG_OPEN) == 0,
		"Pipe should continue to be closed after all waiters have been released");
}

/*
 * Fill the pipe, then issue a write that cannot complete; closing the pipe
 * from another thread must make that write return -EPIPE, and the pipe must
 * remain closed once the waiter has been released.
 */
ZTEST(k_pipe_concurrency, test_close_on_write)
{
	k_tid_t tid;
	uint8_t buffer[DUMMY_DATA_SIZE];
	/* Contents are irrelevant, but initialise them rather than push indeterminate bytes. */
	uint8_t garbage[DUMMY_DATA_SIZE] = {0};

	k_pipe_init(&pipe, buffer, sizeof(buffer));
	/* First write fills the whole buffer so the next write has to wait. */
	zassert_true(sizeof(garbage) == k_pipe_write(&pipe, garbage, sizeof(garbage), K_MSEC(1000)),
		"Failed to write to pipe");

	/* Delay the closer by 100 ms so the write below is expected to be blocked first. */
	tid = k_thread_create(&thread, stack, K_THREAD_STACK_SIZEOF(stack),
		thread_close, &pipe, NULL, NULL, K_PRIO_COOP(0), 0, K_MSEC(100));
	zassert_not_null(tid, "k_thread_create failed");
	zassert_true(k_pipe_write(&pipe, garbage, sizeof(garbage), K_MSEC(1000)) == -EPIPE,
		"write should return -EPIPE, when pipe is closed");
	k_thread_join(tid, K_FOREVER);
	zassert_true((pipe.flags & PIPE_FLAG_OPEN) == 0,
		"pipe should continue to be closed after all waiters have been released");
}

/*
 * A read on an empty pipe must return -ECANCELED when another thread resets
 * the pipe. Afterwards the reset flag must be cleared and the pipe must still
 * be open.
 */
ZTEST(k_pipe_concurrency, test_reset_on_read)
{
	k_tid_t tid;
	uint8_t buffer[DUMMY_DATA_SIZE];
	uint8_t res;

	k_pipe_init(&pipe, buffer, sizeof(buffer));

	/* Delay the reset by 100 ms so the read below is expected to be blocked first. */
	tid = k_thread_create(&thread, stack, K_THREAD_STACK_SIZEOF(stack),
		thread_reset, &pipe, NULL, NULL, K_PRIO_COOP(0), 0, K_MSEC(100));
	zassert_not_null(tid, "k_thread_create failed");
	zassert_true(k_pipe_read(&pipe, &res, sizeof(res), K_MSEC(1000)) == -ECANCELED,
		"reset on read should return -ECANCELED");
	k_thread_join(tid, K_FOREVER);
	zassert_true((pipe.flags & PIPE_FLAG_RESET) == 0,
		"pipe should not have reset flag after all waiters are done");
	zassert_true((pipe.flags & PIPE_FLAG_OPEN) != 0,
		"pipe should continue to be open after pipe is reset");
}

/*
 * Fill the pipe, then issue a write that cannot complete. Resetting the pipe
 * from another thread must make that write return -ECANCELED. Afterwards the
 * reset flag must be cleared and the pipe must still be open.
 */
ZTEST(k_pipe_concurrency, test_reset_on_write)
{
	k_tid_t tid;
	uint8_t buffer[DUMMY_DATA_SIZE];
	/* Contents are irrelevant, but initialise them rather than push indeterminate bytes. */
	uint8_t garbage[DUMMY_DATA_SIZE] = {0};

	k_pipe_init(&pipe, buffer, sizeof(buffer));
	/* First write fills the whole buffer so the next write has to wait. */
	zassert_true(sizeof(garbage) == k_pipe_write(&pipe, garbage, sizeof(garbage), K_MSEC(1000)),
		"Failed to write to pipe");

	/* Delay the reset by 100 ms so the write below is expected to be blocked first. */
	tid = k_thread_create(&thread, stack, K_THREAD_STACK_SIZEOF(stack),
		thread_reset, &pipe, NULL, NULL, K_PRIO_COOP(0), 0, K_MSEC(100));
	zassert_not_null(tid, "k_thread_create failed");
	zassert_true(k_pipe_write(&pipe, garbage, sizeof(garbage), K_MSEC(1000)) == -ECANCELED,
		"reset on write should return -ECANCELED");
	k_thread_join(tid, K_FOREVER);
	zassert_true((pipe.flags & PIPE_FLAG_RESET) == 0,
		"pipe should not have reset flag after all waiters are done");
	zassert_true((pipe.flags & PIPE_FLAG_OPEN) != 0,
		"pipe should continue to be open after pipe is reset");
}

/*
 * The helper thread (thread_read) requests DUMMY_DATA_SIZE bytes with a
 * partial_wait_time timeout. This test supplies them in two halves separated
 * by a sleep of partial_wait_time / 4, so the reader must accept the data in
 * pieces and still end up with the full amount.
 */
ZTEST(k_pipe_concurrency, test_partial_read)
{
	k_tid_t tid;
	uint8_t buffer[DUMMY_DATA_SIZE];
	/* Contents are irrelevant, but initialise them rather than push indeterminate bytes. */
	uint8_t garbage[DUMMY_DATA_SIZE] = {0};
	size_t write_size = sizeof(garbage) / 2;
	int rc;

	k_pipe_init(&pipe, buffer, sizeof(buffer));
	tid = k_thread_create(&thread, stack, K_THREAD_STACK_SIZEOF(stack),
		thread_read, &pipe, NULL, NULL, K_PRIO_COOP(0), 0, K_NO_WAIT);
	zassert_not_null(tid, "k_thread_create failed");

	/* k_pipe_write() returns int; compare against int to avoid a signed/unsigned mix. */
	rc = k_pipe_write(&pipe, garbage, write_size, K_NO_WAIT);
	zassert_true(rc == (int)write_size, "write to pipe failed");
	k_msleep(partial_wait_time / 4);
	rc = k_pipe_write(&pipe, garbage, write_size, K_NO_WAIT);
	zassert_true(rc == (int)write_size,
		"k_pipe_write should return number of bytes written");
	k_thread_join(tid, K_FOREVER);
}

/*
 * With the pipe full, the helper thread (thread_write) issues a
 * DUMMY_DATA_SIZE write with a partial_wait_time timeout. This test drains the
 * pipe in two halves separated by a sleep of partial_wait_time / 2, so the
 * writer can only make progress in pieces and must still write everything.
 */
ZTEST(k_pipe_concurrency, test_partial_write)
{
	k_tid_t tid;
	uint8_t buffer[DUMMY_DATA_SIZE];
	/* Used both as write source and read sink; initialise it before the first write. */
	uint8_t garbage[DUMMY_DATA_SIZE] = {0};
	size_t read_size = sizeof(garbage) / 2;
	int rc;

	k_pipe_init(&pipe, buffer, sizeof(buffer));

	/* Fill the pipe completely so the helper's write has to wait. */
	rc = k_pipe_write(&pipe, garbage, sizeof(garbage), K_NO_WAIT);
	zassert_true(rc == (int)sizeof(garbage), "Failed to write to pipe");
	tid = k_thread_create(&thread, stack, K_THREAD_STACK_SIZEOF(stack),
		thread_write, &pipe, NULL, NULL, K_PRIO_COOP(0), 0, K_NO_WAIT);
	zassert_not_null(tid, "k_thread_create failed");

	/* k_pipe_read() returns int; compare against int to avoid a signed/unsigned mix. */
	rc = k_pipe_read(&pipe, garbage, read_size, K_NO_WAIT);
	zassert_true(rc == (int)read_size, "Failed to read from pipe");
	k_msleep(partial_wait_time / 2);
	rc = k_pipe_read(&pipe, garbage, read_size, K_NO_WAIT);
	zassert_true(rc == (int)read_size, "Failed to read from pipe");
	k_thread_join(tid, K_FOREVER);
}

/* Progress markers set by zero_thread_read_write before each blocking call. */
static volatile bool zero_thread_read = false;
static volatile bool zero_thread_write = false;

/*
 * Helper thread entry point for the zero-size pipe test. It reads
 * DUMMY_DATA_SIZE bytes from the pipe in arg1 and writes the same bytes to the
 * pipe in arg2, waiting forever on each call. The local buffer is pre-filled
 * with 0xBB, so any byte the read fails to overwrite shows up in the caller's
 * comparison. arg3 is unused.
 */
static void zero_thread_read_write(void *arg1, void *arg2, void *arg3)
{
	struct k_pipe *in = arg1;
	struct k_pipe *out = arg2;
	uint8_t relay[DUMMY_DATA_SIZE];
	int rc;

	(void)arg3;

	memset(relay, 0xBB, sizeof(relay));

	zero_thread_read = true;
	rc = k_pipe_read(in, relay, sizeof(relay), K_FOREVER);
	zassert_true(rc == sizeof(relay), "Failed to read from pipe");

	zero_thread_write = true;
	rc = k_pipe_write(out, relay, sizeof(relay), K_FOREVER);
	zassert_true(rc == sizeof(relay), "Failed to write to pipe");
}

/*
 * Round-trip data through two pipes created with no buffer (size 0).
 * The helper thread reads from input_pipe and echoes the bytes into
 * output_pipe. The test then checks that the echoed bytes match what it sent
 * and that the helper reached both of its pipe calls.
 */
ZTEST(k_pipe_concurrency, test_zero_size_pipe_read_write)
{
	k_tid_t tid;
	struct k_pipe input_pipe;
	struct k_pipe output_pipe;
	uint8_t input[DUMMY_DATA_SIZE];
	uint8_t output[DUMMY_DATA_SIZE];

#ifdef CONFIG_KERNEL_COHERENCE
	/* Zero size pipes are not supported due to requiring cache
	 * management on data buffers as the buffers can reside in
	 * incoherent memory. So skip this test.
	 */
	ztest_test_skip();
#endif

	/*
	 * Clear the progress markers so a repeated run of this test cannot pass
	 * on values left over from an earlier run. This is done before the
	 * helper thread is started.
	 */
	zero_thread_read = false;
	zero_thread_write = false;

	/* Distinct fill patterns so an unmodified buffer fails the memcmp below. */
	memset(input, 0xAA, sizeof(input));
	memset(output, 0xCC, sizeof(output));
	k_pipe_init(&input_pipe, NULL, 0);
	k_pipe_init(&output_pipe, NULL, 0);

	tid = k_thread_create(&thread, stack, K_THREAD_STACK_SIZEOF(stack),
		zero_thread_read_write, &input_pipe, &output_pipe, NULL, K_PRIO_COOP(0), 0,
		K_NO_WAIT);
	zassert_not_null(tid, "k_thread_create failed");

	zassert_true(sizeof(input) == k_pipe_write(&input_pipe, input, sizeof(input), K_FOREVER),
	      "Failed to write to pipe");
	zassert_true(sizeof(output) == k_pipe_read(&output_pipe, output, sizeof(output), K_FOREVER),
	      "Failed to read from pipe");
	zassert_true(memcmp(input, output, sizeof(input)) == 0,
		"Unexpected data received from pipe");

	zassert_true(zero_thread_read && zero_thread_write,
		"Thread did not execute expected read/write operations");

	k_thread_join(tid, K_FOREVER);
}
