1 /*
2  * Copyright (c) 2024 Meta Platforms
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <stdio.h>
8 
9 #include <zephyr/arch/riscv/reg.h>
10 #include <zephyr/kernel.h>
11 #include <zephyr/ztest.h>
12 
13 #if !defined(CONFIG_RISCV_GP) && !defined(CONFIG_RISCV_CURRENT_VIA_GP)
14 #error "CONFIG_RISCV_GP or CONFIG_RISCV_CURRENT_VIA_GP must be enabled for this test"
15 #endif
16 
17 #define ROGUE_USER_STACK_SZ 2048
18 
19 static struct k_thread rogue_user_thread;
20 static K_THREAD_STACK_DEFINE(rogue_user_stack, ROGUE_USER_STACK_SZ);
21 
rogue_user_fn(void * p1,void * p2,void * p3)22 static void rogue_user_fn(void *p1, void *p2, void *p3)
23 {
24 	zassert_true(k_is_user_context());
25 	uintptr_t gp_val = reg_read(gp);
26 	uintptr_t gp_test_val;
27 
28 	/* Make sure that `gp` is as expected */
29 	if (IS_ENABLED(CONFIG_RISCV_GP)) {
30 		__asm__ volatile("la %0, __global_pointer$" : "=r" (gp_test_val));
31 	} else { /* CONFIG_RISCV_CURRENT_VIA_GP */
32 		gp_test_val = (uintptr_t)k_current_get();
33 	}
34 
35 	/* Corrupt `gp` reg */
36 	reg_write(gp, 0xbad);
37 
38 	/* Make sure that `gp` is corrupted */
39 	if (IS_ENABLED(CONFIG_RISCV_GP)) {
40 		zassert_equal(reg_read(gp), 0xbad);
41 	} else { /* CONFIG_RISCV_CURRENT_VIA_GP */
42 		zassert_equal((uintptr_t)_current, 0xbad);
43 	}
44 
45 	/* Sleep to force a context switch, which will sanitize `gp` */
46 	k_msleep(50);
47 
48 	/* Make sure that `gp` is sane again */
49 	if (IS_ENABLED(CONFIG_RISCV_GP)) {
50 		__asm__ volatile("la %0, __global_pointer$" : "=r" (gp_test_val));
51 	} else { /* CONFIG_RISCV_CURRENT_VIA_GP */
52 		gp_test_val = (uintptr_t)k_current_get();
53 	}
54 
55 	zassert_equal(gp_val, gp_test_val);
56 }
57 
ZTEST_USER(riscv_gp,test_gp_value)58 ZTEST_USER(riscv_gp, test_gp_value)
59 {
60 	uintptr_t gp_val = reg_read(gp);
61 	uintptr_t gp_test_val;
62 	k_tid_t th;
63 
64 	if (IS_ENABLED(CONFIG_RISCV_GP)) {
65 		__asm__ volatile("la %0, __global_pointer$" : "=r" (gp_test_val));
66 	} else { /* CONFIG_RISCV_CURRENT_VIA_GP */
67 		gp_test_val = (uintptr_t)k_current_get();
68 	}
69 	zassert_equal(gp_val, gp_test_val);
70 
71 	/* Create and run a rogue thread to corrupt the `gp` */
72 	th = k_thread_create(&rogue_user_thread, rogue_user_stack, ROGUE_USER_STACK_SZ,
73 			     rogue_user_fn, NULL, NULL, NULL, -1, K_USER, K_NO_WAIT);
74 	zassert_ok(k_thread_join(th, K_FOREVER));
75 
76 	/* Make sure that `gp` is the same as before a rogue thread was executed */
77 	zassert_equal(reg_read(gp), gp_val, "`gp` corrupted by user thread");
78 }
79 
userspace_setup(void)80 static void *userspace_setup(void)
81 {
82 	k_thread_access_grant(k_current_get(), &rogue_user_thread, &rogue_user_stack);
83 
84 	return NULL;
85 }
86 
/* Register the `riscv_gp` suite; only a setup hook is provided */
ZTEST_SUITE(riscv_gp,
	    NULL,            /* predicate */
	    userspace_setup, /* setup */
	    NULL,            /* before */
	    NULL,            /* after */
	    NULL);           /* teardown */
88