1 /*
2  * Multi-precision integer library
3  * ESP32 S2 hardware accelerated parts based on mbedTLS implementation
4  *
5  * SPDX-FileCopyrightText: The Mbed TLS Contributors
6  *
7  * SPDX-License-Identifier: Apache-2.0
8  *
9  * SPDX-FileContributor: 2016-2022 Espressif Systems (Shanghai) CO LTD
10  */
11 #include "soc/hwcrypto_periph.h"
12 #include "esp_private/periph_ctrl.h"
13 #include <mbedtls/bignum.h>
14 #include "bignum_impl.h"
15 #include "soc/dport_reg.h"
16 #include "soc/periph_defs.h"
17 #include <sys/param.h>
18 #include "esp_crypto_lock.h"
19 
/* Map an operand length in words to the length the RSA hardware needs.

   This implementation applies no rounding or padding: the returned value is
   always equal to 'words'.
*/
size_t esp_mpi_hardware_words(size_t words)
{
    const size_t hw_words = words; /* identity mapping on this target */

    return hw_words;
}
24 
/* Acquire and power up the RSA peripheral for MPI operations.
 *
 * Takes the MPI crypto lock (released again by
 * esp_mpi_disable_hardware_hw_op()), enables the RSA peripheral module,
 * clears the RSA memory power-down bit, waits for the peripheral to report
 * ready, and finally writes 0 to RSA_INTERRUPT_REG (the same register that
 * esp_mpi_interrupt_enable() writes), i.e. starts with the interrupt off.
 *
 * Must be paired with esp_mpi_disable_hardware_hw_op().
 */
void esp_mpi_enable_hardware_hw_op( void )
{
    /* Serialize access to the shared RSA peripheral. */
    esp_crypto_mpi_lock_acquire();

    /* Enable RSA hardware */
    periph_module_enable(PERIPH_RSA_MODULE);

    /* Clear the RSA memory power-down bit (presumably powering the operand
       memory blocks up; inverse of the disable path). */
    DPORT_REG_CLR_BIT(DPORT_RSA_PD_CTRL_REG, DPORT_RSA_MEM_PD);

    /* Busy-wait (no timeout) until RSA_QUERY_CLEAN_REG reads 1.
       NOTE(review): presumably signals that the hardware has finished
       initializing/clearing its memory blocks - confirm against the TRM. */
    while (DPORT_REG_READ(RSA_QUERY_CLEAN_REG) != 1) {
    }
    // Note: from enabling RSA clock to here takes about 1.3us

    /* Start with the RSA interrupt disabled. */
    REG_WRITE(RSA_INTERRUPT_REG, 0);
}
40 
esp_mpi_disable_hardware_hw_op(void)41 void esp_mpi_disable_hardware_hw_op( void )
42 {
43     DPORT_REG_SET_BIT(DPORT_RSA_PD_CTRL_REG, DPORT_RSA_PD);
44 
45     /* Disable RSA hardware */
46     periph_module_disable(PERIPH_RSA_MODULE);
47 
48     esp_crypto_mpi_lock_release();
49 }
50 
esp_mpi_interrupt_enable(bool enable)51 void esp_mpi_interrupt_enable( bool enable )
52 {
53     REG_WRITE(RSA_INTERRUPT_REG, enable);
54 }
55 
esp_mpi_interrupt_clear(void)56 void esp_mpi_interrupt_clear( void )
57 {
58     REG_WRITE(RSA_CLEAR_INTERRUPT_REG, 1);
59 }
60 
61 /* Copy mbedTLS MPI bignum 'mpi' to hardware memory block at 'mem_base'.
62 
63    If num_words is higher than the number of words in the bignum then
64    these additional words will be zeroed in the memory buffer.
65 */
mpi_to_mem_block(uint32_t mem_base,const mbedtls_mpi * mpi,size_t num_words)66 static inline void mpi_to_mem_block(uint32_t mem_base, const mbedtls_mpi *mpi, size_t num_words)
67 {
68     uint32_t *pbase = (uint32_t *)mem_base;
69     uint32_t copy_words = MIN(num_words, mpi->MBEDTLS_PRIVATE(n));
70 
71     /* Copy MPI data to memory block registers */
72     for (uint32_t i = 0; i < copy_words; i++) {
73         pbase[i] = mpi->MBEDTLS_PRIVATE(p)[i];
74     }
75 
76     /* Zero any remaining memory block data */
77     for (uint32_t i = copy_words; i < num_words; i++) {
78         pbase[i] = 0;
79     }
80 }
81 
82 /* Read mbedTLS MPI bignum back from hardware memory block.
83 
84    Reads num_words words from block.
85 */
mem_block_to_mpi(mbedtls_mpi * x,uint32_t mem_base,int num_words)86 static inline void mem_block_to_mpi(mbedtls_mpi *x, uint32_t mem_base, int num_words)
87 {
88 
89     /* Copy data from memory block registers */
90     esp_dport_access_read_buffer(x->MBEDTLS_PRIVATE(p), mem_base, num_words);
91     /* Zero any remaining limbs in the bignum, if the buffer is bigger
92        than num_words */
93     for (size_t i = num_words; i < x->MBEDTLS_PRIVATE(n); i++) {
94         x->MBEDTLS_PRIVATE(p)[i] = 0;
95     }
96 }
97 
98 
99 
100 /* Begin an RSA operation. op_reg specifies which 'START' register
101    to write to.
102 */
start_op(uint32_t op_reg)103 static inline void start_op(uint32_t op_reg)
104 {
105     /* Clear interrupt status */
106     DPORT_REG_WRITE(RSA_CLEAR_INTERRUPT_REG, 1);
107 
108     /* Note: above REG_WRITE includes a memw, so we know any writes
109        to the memory blocks are also complete. */
110 
111     DPORT_REG_WRITE(op_reg, 1);
112 }
113 
114 /* Wait for an RSA operation to complete.
115 */
wait_op_complete(void)116 static inline void wait_op_complete(void)
117 {
118     while (DPORT_REG_READ(RSA_QUERY_INTERRUPT_REG) != 1)
119     { }
120 
121     /* clear the interrupt */
122     DPORT_REG_WRITE(RSA_CLEAR_INTERRUPT_REG, 1);
123 }
124 
125 
126 /* Read result from last MPI operation */
esp_mpi_read_result_hw_op(mbedtls_mpi * Z,size_t z_words)127 void esp_mpi_read_result_hw_op(mbedtls_mpi *Z, size_t z_words)
128 {
129     wait_op_complete();
130     mem_block_to_mpi(Z, RSA_MEM_Z_BLOCK_BASE, z_words);
131 }
132 
133 
134 /* Z = (X * Y) mod M
135 
136    Not an mbedTLS function
137 */
esp_mpi_mul_mpi_mod_hw_op(const mbedtls_mpi * X,const mbedtls_mpi * Y,const mbedtls_mpi * M,const mbedtls_mpi * Rinv,mbedtls_mpi_uint Mprime,size_t num_words)138 void esp_mpi_mul_mpi_mod_hw_op(const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M, const mbedtls_mpi *Rinv, mbedtls_mpi_uint Mprime, size_t num_words)
139 {
140     DPORT_REG_WRITE(RSA_LENGTH_REG, (num_words - 1));
141 
142     /* Load M, X, Rinv, Mprime (Mprime is mod 2^32) */
143     mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
144     mpi_to_mem_block(RSA_MEM_Y_BLOCK_BASE, Y, num_words);
145     mpi_to_mem_block(RSA_MEM_M_BLOCK_BASE, M, num_words);
146     mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, Rinv, num_words);
147     DPORT_REG_WRITE(RSA_M_DASH_REG, Mprime);
148 
149     start_op(RSA_MOD_MULT_START_REG);
150 }
151 
152 /* Z = (X ^ Y) mod M
153 */
esp_mpi_exp_mpi_mod_hw_op(const mbedtls_mpi * X,const mbedtls_mpi * Y,const mbedtls_mpi * M,const mbedtls_mpi * Rinv,mbedtls_mpi_uint Mprime,size_t num_words)154 void esp_mpi_exp_mpi_mod_hw_op(const mbedtls_mpi *X, const mbedtls_mpi *Y, const mbedtls_mpi *M, const mbedtls_mpi *Rinv, mbedtls_mpi_uint Mprime, size_t num_words)
155 {
156     size_t y_bits = mbedtls_mpi_bitlen(Y);
157 
158     DPORT_REG_WRITE(RSA_LENGTH_REG, (num_words - 1));
159 
160     /* Load M, X, Rinv, Mprime (Mprime is mod 2^32) */
161     mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
162     mpi_to_mem_block(RSA_MEM_Y_BLOCK_BASE, Y, num_words);
163     mpi_to_mem_block(RSA_MEM_M_BLOCK_BASE, M, num_words);
164     mpi_to_mem_block(RSA_MEM_RB_BLOCK_BASE, Rinv, num_words);
165     DPORT_REG_WRITE(RSA_M_DASH_REG, Mprime);
166 
167     /* Enable acceleration options */
168     DPORT_REG_WRITE(RSA_CONSTANT_TIME_REG, 0);
169     DPORT_REG_WRITE(RSA_SEARCH_OPEN_REG, 1);
170     DPORT_REG_WRITE(RSA_SEARCH_POS_REG, y_bits - 1);
171 
172     /* Execute first stage montgomery multiplication */
173     start_op(RSA_MODEXP_START_REG);
174 
175     DPORT_REG_WRITE(RSA_SEARCH_OPEN_REG, 0);
176 }
177 
178 
179 /* Z = X * Y */
esp_mpi_mul_mpi_hw_op(const mbedtls_mpi * X,const mbedtls_mpi * Y,size_t num_words)180 void esp_mpi_mul_mpi_hw_op(const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t num_words)
181 {
182     /* Copy X (right-extended) & Y (left-extended) to memory block */
183     mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
184     mpi_to_mem_block(RSA_MEM_Z_BLOCK_BASE + num_words * 4, Y, num_words);
185     /* NB: as Y is left-extended, we don't zero the bottom words_mult words of Y block.
186        This is OK for now because zeroing is done by hardware when we do esp_mpi_acquire_hardware().
187     */
188     DPORT_REG_WRITE(RSA_LENGTH_REG, (num_words * 2 - 1));
189     start_op(RSA_MULT_START_REG);
190 }
191 
192 
193 
194 /**
195  * @brief Special-case of (X * Y), where we use hardware montgomery mod
196    multiplication to calculate result where either A or B are >2048 bits so
197    can't use the standard multiplication method.
198  *
199  */
esp_mpi_mult_mpi_failover_mod_mult_hw_op(const mbedtls_mpi * X,const mbedtls_mpi * Y,size_t num_words)200 void esp_mpi_mult_mpi_failover_mod_mult_hw_op(const mbedtls_mpi *X, const mbedtls_mpi *Y, size_t num_words)
201 {
202     /* M = 2^num_words - 1, so block is entirely FF */
203     for (size_t i = 0; i < num_words; i++) {
204         DPORT_REG_WRITE(RSA_MEM_M_BLOCK_BASE + i * 4, UINT32_MAX);
205     }
206 
207     /* Mprime = 1 */
208     DPORT_REG_WRITE(RSA_M_DASH_REG, 1);
209     DPORT_REG_WRITE(RSA_LENGTH_REG, num_words - 1);
210 
211     /* Load X & Y */
212     mpi_to_mem_block(RSA_MEM_X_BLOCK_BASE, X, num_words);
213     mpi_to_mem_block(RSA_MEM_Y_BLOCK_BASE, Y, num_words);
214 
215     /* Rinv = 1, write first word */
216     DPORT_REG_WRITE(RSA_MEM_RB_BLOCK_BASE, 1);
217 
218      /* Zero out rest of the Rinv words */
219     for (size_t i = 1; i < num_words; i++) {
220         DPORT_REG_WRITE(RSA_MEM_RB_BLOCK_BASE + i * 4, 0);
221     }
222 
223     start_op(RSA_MOD_MULT_START_REG);
224 }
225