1 /*
2  * Copyright (c) 2001-2019, Arm Limited and Contributors. All rights reserved.
3  *
4  * SPDX-License-Identifier: BSD-3-Clause
5  */
6 
7 
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 
12 #include <openssl/objects.h>
13 #include <openssl/pem.h>
14 #include <openssl/evp.h>
15 #include <openssl/rand.h>
16 #include <openssl/bn.h>
17 #include <openssl/aes.h>
18 #include <openssl/err.h>
19 #include "common_rsa_keypair.h"
20 #include "common_util_log.h"
21 #include "cc_pka_hw_plat_defs.h"
22 
23 /**
24  * @brief The CC_CommonGetKeyPair reads RSA private key from the file, along with retrieving the private key,
25  *    it also retrieves the public key.
26  *
27  * The function
28  * 1. Build RSA public key structure
29  * @param[out] pRsaPrivKey - the private key
30  * @param[in] PemEncryptedFileName_ptr - private key file
31  * @param[in] Key_ptr - passphrase string
32  *
33  */
34 /*********************************************************/
CC_CommonGetKeyPair(RSA ** pRsaKeyPair,int8_t * PemEncryptedFileName_ptr,int8_t * Key_ptr)35 int32_t CC_CommonGetKeyPair(RSA **pRsaKeyPair, int8_t *PemEncryptedFileName_ptr, int8_t *Key_ptr)
36 {
37     FILE *fp = NULL;
38 
39     if (PemEncryptedFileName_ptr == NULL) {
40         UTIL_LOG_ERR("Illegal RSA key pair or pwd file name\n");
41         return -1;
42     }
43 
44     fp = fopen (PemEncryptedFileName_ptr, "r");
45     if (fp == NULL) {
46         UTIL_LOG_ERR("Cannot open RSA file %s\n", PemEncryptedFileName_ptr);
47         return -1;
48     }
49 
50 
51     if ((PEM_read_RSAPrivateKey (fp, pRsaKeyPair, NULL, Key_ptr)) == NULL) {
52         UTIL_LOG_ERR("Cannot read RSA private key\n");
53         ERR_print_errors_fp(stderr);
54         fclose (fp);
55         return -1;
56     }
57 
58     fclose (fp);
59     return 0;
60 }
61 
62 /**
63  * @brief The CC_CommonGetPubKey reads RSA public key from the file.
64  *
65  * The function
66  * 1. Build RSA public key structure
67  * @param[out] pRsaPrivKey - the rsa key
68  * @param[in] PemEncryptedFileName_ptr - public key file name
69  *
70  */
71 /*********************************************************/
CC_CommonGetPubKey(RSA ** pRsaKeyPair,int8_t * PemEncryptedFileName_ptr)72 int32_t CC_CommonGetPubKey(RSA **pRsaKeyPair, int8_t *PemEncryptedFileName_ptr)
73 {
74     FILE *fp = NULL;
75 
76     if (PemEncryptedFileName_ptr == NULL) {
77         UTIL_LOG_ERR("Illegal RSA file name\n");
78         return -1;
79     }
80 
81     fp = fopen (PemEncryptedFileName_ptr, "r");
82     if (fp == NULL) {
83         UTIL_LOG_ERR("Cannot open RSA file %s\n", PemEncryptedFileName_ptr);
84         return -1;
85     }
86 
87 
88     if ((PEM_read_RSA_PUBKEY(fp, pRsaKeyPair, NULL, NULL)) == NULL) {
89         UTIL_LOG_ERR("Cannot read RSA public key\n");
90         ERR_print_errors_fp(stderr);
91         fclose (fp);
92         return -1;
93     }
94 
95     fclose (fp);
96     return 0;
97 }
98 
99 /**
100 * @brief The function CC_CommonRsaCalculateNp calculates the Np it returns it as array of ascii's
101 *
102 * @param[in] N_ptr - public key N, represented as array of ascii's (0xbc is translated
103 *                    to 0x62 0x63)
104 * @param[out] NP_ptr - The NP result. NP size is NP_SIZE_IN_BYTES*2 + 1
105 *
106 */
107 /*********************************************************/
CC_CommonRsaCalculateNp(const int8_t * N_ptr,int8_t * NP_ptr)108 SBUEXPORT_C int32_t CC_CommonRsaCalculateNp(const int8_t *N_ptr,
109                       int8_t *NP_ptr)
110 {
111     int8_t *N_Temp = NULL;
112     int32_t  status  = -1;
113     BIGNUM *bn_n = BN_new();
114 
115     if ((NULL == N_ptr) || (NULL == NP_ptr)) {
116         UTIL_LOG_ERR("Illegal input\n");
117         goto calcNp_end;
118     }
119 
120     /* Copy the N to temporary N, allocate temporary N in N size + 2 */
121     N_Temp= (int8_t *)malloc ((SB_CERT_RSA_KEY_SIZE_IN_BYTES * 2 + 2) * sizeof (int8_t));
122     if (NULL == N_Temp) {
123         UTIL_LOG_ERR("failed to malloc.\n");
124         goto calcNp_end;
125     }
126 
127     if (NULL == bn_n) {
128         UTIL_LOG_ERR("failed to BN_new.\n");
129         goto calcNp_end;
130     }
131 
132     /* set the temporary N to 0 */
133     memset(N_Temp, 0, (SB_CERT_RSA_KEY_SIZE_IN_BYTES * 2 + 2));
134 
135     /* Copy the N to temp N */
136     memcpy (N_Temp, N_ptr, SB_CERT_RSA_KEY_SIZE_IN_BYTES * 2);
137 
138     if (!BN_hex2bn (&bn_n, N_Temp)) {
139         UTIL_LOG_ERR("BN_hex2bn failed.\n");
140         goto calcNp_end;
141     }
142 
143     if (CC_CommonRSACalculateNpInt(bn_n, NP_ptr, NP_HEX) != 0) {
144         UTIL_LOG_ERR("CC_CommonRSACalculateNpInt failed.\n");
145         goto calcNp_end;
146     }
147 
148     status = 0;
149 
150     calcNp_end:
151     if (N_Temp != NULL) {
152         free(N_Temp);
153     }
154     if (bn_n != NULL) {
155         BN_free (bn_n);
156     }
157     return(status);
158 }
159 
160 /**
161  * @brief The function calculates Np when given N as BIGNUM.
162  *
163  * @param[in] n - modulus as BIGNUM ptr
164  * @param[out] NP_ptr - the Np
165  *
166  */
167 /*********************************************************/
CC_CommonRSACalculateNpInt(BIGNUM * n,uint8_t * NP_ptr,NP_RESULT_TYPE_t resultType)168 SBUEXPORT_C int32_t CC_CommonRSACalculateNpInt(BIGNUM *n,
169                          uint8_t *NP_ptr,
170                          NP_RESULT_TYPE_t resultType)
171 {
172     int32_t len;
173     uint8_t *NP_res = NULL, *NP_resTemp = NULL;
174     int32_t  status  = -1;
175     BN_CTX *bn_ctx = BN_CTX_new();
176 
177     BIGNUM *bn_r   = BN_new();
178     BIGNUM *bn_a   = BN_new();
179     BIGNUM *bn_p   = BN_new();
180     BIGNUM *bn_n   = BN_new();
181     BIGNUM *bn_quo = BN_new();
182     BIGNUM *bn_rem = BN_new();
183 
184     if ((NULL == n) || (NULL == NP_ptr)) {
185         UTIL_LOG_ERR("Illegal input parameters.\n");
186         goto calcNpInt_end;
187     }
188 
189     NP_res = (int8_t*)malloc(NP_SIZE_IN_BYTES);
190     if (NP_res == NULL) {
191         UTIL_LOG_ERR("failed to malloc.\n");
192         goto calcNpInt_end;
193     }
194     if ((NULL == bn_r) ||
195         (NULL == bn_a) ||
196         (NULL == bn_p) ||
197         (NULL == bn_n) ||
198         (NULL == bn_quo) ||
199         (NULL == bn_rem) ||
200         (NULL == bn_ctx)) {
201         UTIL_LOG_ERR("failed to BN_new or BN_CTX_new.\n");
202         goto calcNpInt_end;
203     }
204 
205     /* computes a = 2^SNP */
206     BN_set_word (bn_a, 2);
207     BN_set_word (bn_p, SNP);
208     if (!BN_exp (bn_r, bn_a, bn_p, bn_ctx)) {
209         UTIL_LOG_ERR("failed to BN_exp.\n");
210         goto calcNpInt_end;
211     }
212     if (!BN_div (bn_quo, bn_rem, bn_r, n, bn_ctx)) {
213         UTIL_LOG_ERR("failed to BN_div.\n");
214         goto calcNpInt_end;
215     }
216 
217     if (resultType == NP_BIN) {
218         len = BN_bn2bin (bn_quo, NP_res);
219 
220         /* Set the output with 0 and than copy the result */
221         memset (NP_ptr, 0, NP_SIZE_IN_BYTES);
222         memcpy ((uint8_t *)(NP_ptr + (NP_SIZE_IN_BYTES - len)), (int8_t *)NP_res, len);
223     } else { /* resultType == HEX*/
224         NP_resTemp = BN_bn2hex (bn_quo);
225         if (NP_resTemp == NULL) {
226             UTIL_LOG_ERR("BN_bn2hex failed\n");
227             goto calcNpInt_end;
228         }
229         if (NP_resTemp[0] == '-'){
230             UTIL_LOG_ERR("BN_bn2hex returned negative values\n");
231             goto calcNpInt_end;
232         }
233         len = (int32_t)strlen (NP_resTemp);
234         memcpy(NP_res, NP_resTemp, len);
235 
236         /* Set the output with 0 and than copy the result */
237         memset (NP_ptr, 0, (NP_SIZE_IN_BYTES * 2 + 2));
238         memcpy ((int8_t *)(NP_ptr + (NP_SIZE_IN_BYTES * 2 + 2 - len)), (int8_t *)NP_res, len);
239     }
240 
241     status = 0;
242 
243     calcNpInt_end:
244     if (NP_res != NULL) {
245         free(NP_res);
246     }
247     if (bn_r != NULL) {
248         BN_free (bn_r);
249     }
250     if (bn_a != NULL) {
251         BN_free (bn_a);
252     }
253     if (bn_p != NULL) {
254         BN_free (bn_p);
255     }
256     if (bn_n != NULL) {
257         BN_free (bn_n);
258     }
259     if (bn_quo != NULL) {
260         BN_free (bn_quo);
261     }
262     if (bn_rem != NULL) {
263         BN_free (bn_rem);
264     }
265     if (bn_ctx != NULL) {
266         BN_CTX_free(bn_ctx);
267     }
268     if (NP_resTemp != NULL){
269         OPENSSL_free(NP_resTemp);
270     }
271     return(status);
272 }
273 
274 
275