1 /*
2  *  The LMS stateful-hash public-key signature scheme
3  *
4  *  Copyright The Mbed TLS Contributors
5  *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
6  */
7 
8 /*
9  *  The following sources were referenced in the design of this implementation
10  *  of the LMS algorithm:
11  *
12  *  [1] IETF RFC8554
13  *      D. McGrew, M. Curcio, S.Fluhrer
14  *      https://datatracker.ietf.org/doc/html/rfc8554
15  *
16  *  [2] NIST Special Publication 800-208
17  *      David A. Cooper et. al.
18  *      https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-208.pdf
19  */
20 
21 #include "common.h"
22 
23 #if defined(MBEDTLS_LMS_C)
24 
25 #include <string.h>
26 
27 #include "lmots.h"
28 
29 #include "psa/crypto.h"
30 #include "psa_util_internal.h"
31 #include "mbedtls/lms.h"
32 #include "mbedtls/error.h"
33 #include "mbedtls/platform_util.h"
34 
35 #include "mbedtls/platform.h"
36 
37 /* Define a local translating function to save code size by not using too many
38  * arguments in each translating place. */
local_err_translation(psa_status_t status)39 static int local_err_translation(psa_status_t status)
40 {
41     return psa_status_to_mbedtls(status, psa_to_lms_errors,
42                                  ARRAY_LENGTH(psa_to_lms_errors),
43                                  psa_generic_status_to_mbedtls);
44 }
45 #define PSA_TO_MBEDTLS_ERR(status) local_err_translation(status)
46 
/* Byte offsets of the fields of an encoded LMS signature, which is laid out
 * as:
 *
 *   q (leaf index) || LM-OTS signature || LMS type || authentication path
 *
 * The LM-OTS signature length depends on the LM-OTS type, so every offset
 * after it takes the otstype as an argument. */
#define SIG_Q_LEAF_ID_OFFSET     (0)
#define SIG_OTS_SIG_OFFSET       (SIG_Q_LEAF_ID_OFFSET + \
                                  MBEDTLS_LMOTS_Q_LEAF_ID_LEN)
#define SIG_TYPE_OFFSET(otstype) (SIG_OTS_SIG_OFFSET   + \
                                  MBEDTLS_LMOTS_SIG_LEN(otstype))
#define SIG_PATH_OFFSET(otstype) (SIG_TYPE_OFFSET(otstype) + \
                                  MBEDTLS_LMS_TYPE_LEN)

/* Byte offsets of the fields of an encoded LMS public key, which is laid out
 * as:
 *
 *   LMS type || LM-OTS type || I (key identifier) || T[1] (Merkle root) */
#define PUBLIC_KEY_TYPE_OFFSET      (0)
#define PUBLIC_KEY_OTSTYPE_OFFSET   (PUBLIC_KEY_TYPE_OFFSET + \
                                     MBEDTLS_LMS_TYPE_LEN)
#define PUBLIC_KEY_I_KEY_ID_OFFSET  (PUBLIC_KEY_OTSTYPE_OFFSET  + \
                                     MBEDTLS_LMOTS_TYPE_LEN)
#define PUBLIC_KEY_ROOT_NODE_OFFSET (PUBLIC_KEY_I_KEY_ID_OFFSET + \
                                     MBEDTLS_LMOTS_I_KEY_ID_LEN)


/* Merkle tree sizing. Currently only H=10 (MBEDTLS_LMS_SHA256_M32_H10) is
 * supported. NOTE(review): H_TREE_HEIGHT_MAX is not referenced in the code
 * visible here; confirm whether it is still needed.
 *
 * Tree nodes are 1-indexed: the root is node 1, the children of node r are
 * nodes 2r and 2r + 1, and slot 0 is never used. Hence:
 *   MERKLE_TREE_NODE_AM          2^(h+1) node slots (including unused slot 0),
 *                                the size to allocate for a full tree.
 *   MERKLE_TREE_LEAF_NODE_AM     2^h leaves, one per LM-OTS key, which is
 *                                also the number of signatures a key allows.
 *   MERKLE_TREE_INTERNAL_NODE_AM 2^h; internal nodes occupy indices
 *                                1 .. 2^h - 1, so this value is the index of
 *                                the leftmost leaf, which is how it is used. */
#define H_TREE_HEIGHT_MAX                  10
#define MERKLE_TREE_NODE_AM(type)          ((size_t) 1 << (MBEDTLS_LMS_H_TREE_HEIGHT(type) + 1u))
#define MERKLE_TREE_LEAF_NODE_AM(type)     ((size_t) 1 << MBEDTLS_LMS_H_TREE_HEIGHT(type))
#define MERKLE_TREE_INTERNAL_NODE_AM(type) ((size_t) 1 << MBEDTLS_LMS_H_TREE_HEIGHT(type))

/* Domain-separation constants of RFC 8554 section 5.3, hashed into every
 * node value: D_LEAF (0x8282) for leaf nodes and D_INTR (0x8383) for
 * internal nodes. */
#define D_CONST_LEN           (2)
static const unsigned char D_LEAF_CONSTANT_BYTES[D_CONST_LEN] = { 0x82, 0x82 };
static const unsigned char D_INTR_CONSTANT_BYTES[D_CONST_LEN] = { 0x83, 0x83 };
73 
74 
75 /* Calculate the value of a leaf node of the Merkle tree (which is a hash of a
76  * public key and some other parameters like the leaf index). This function
77  * implements RFC8554 section 5.3, in the case where r >= 2^h.
78  *
79  *  params              The LMS parameter set, the underlying LMOTS
80  *                      parameter set, and I value which describe the key
81  *                      being used.
82  *
83  *  pub_key             The public key of the private whose index
84  *                      corresponds to the index of this leaf node. This
85  *                      is a hash output.
86  *
87  *  r_node_idx          The index of this node in the Merkle tree. Note
88  *                      that the root node of the Merkle tree is
89  *                      1-indexed.
90  *
91  *  out                 The output node value, which is a hash output.
92  */
create_merkle_leaf_value(const mbedtls_lms_parameters_t * params,unsigned char * pub_key,unsigned int r_node_idx,unsigned char * out)93 static int create_merkle_leaf_value(const mbedtls_lms_parameters_t *params,
94                                     unsigned char *pub_key,
95                                     unsigned int r_node_idx,
96                                     unsigned char *out)
97 {
98     psa_hash_operation_t op;
99     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
100     size_t output_hash_len;
101     unsigned char r_node_idx_bytes[4];
102 
103     op = psa_hash_operation_init();
104     status = psa_hash_setup(&op, PSA_ALG_SHA_256);
105     if (status != PSA_SUCCESS) {
106         goto exit;
107     }
108 
109     status = psa_hash_update(&op, params->I_key_identifier,
110                              MBEDTLS_LMOTS_I_KEY_ID_LEN);
111     if (status != PSA_SUCCESS) {
112         goto exit;
113     }
114 
115     mbedtls_lms_unsigned_int_to_network_bytes(r_node_idx, 4, r_node_idx_bytes);
116     status = psa_hash_update(&op, r_node_idx_bytes, 4);
117     if (status != PSA_SUCCESS) {
118         goto exit;
119     }
120 
121     status = psa_hash_update(&op, D_LEAF_CONSTANT_BYTES, D_CONST_LEN);
122     if (status != PSA_SUCCESS) {
123         goto exit;
124     }
125 
126     status = psa_hash_update(&op, pub_key,
127                              MBEDTLS_LMOTS_N_HASH_LEN(params->otstype));
128     if (status != PSA_SUCCESS) {
129         goto exit;
130     }
131 
132     status = psa_hash_finish(&op, out, MBEDTLS_LMS_M_NODE_BYTES(params->type),
133                              &output_hash_len);
134     if (status != PSA_SUCCESS) {
135         goto exit;
136     }
137 
138 exit:
139     psa_hash_abort(&op);
140 
141     return PSA_TO_MBEDTLS_ERR(status);
142 }
143 
144 /* Calculate the value of an internal node of the Merkle tree (which is a hash
145  * of a public key and some other parameters like the node index). This function
146  * implements RFC8554 section 5.3, in the case where r < 2^h.
147  *
148  *  params              The LMS parameter set, the underlying LMOTS
149  *                      parameter set, and I value which describe the key
150  *                      being used.
151  *
152  *  left_node           The value of the child of this node which is on
153  *                      the left-hand side. As with all nodes on the
154  *                      Merkle tree, this is a hash output.
155  *
156  *  right_node          The value of the child of this node which is on
157  *                      the right-hand side. As with all nodes on the
158  *                      Merkle tree, this is a hash output.
159  *
160  *  r_node_idx          The index of this node in the Merkle tree. Note
161  *                      that the root node of the Merkle tree is
162  *                      1-indexed.
163  *
164  *  out                 The output node value, which is a hash output.
165  */
create_merkle_internal_value(const mbedtls_lms_parameters_t * params,const unsigned char * left_node,const unsigned char * right_node,unsigned int r_node_idx,unsigned char * out)166 static int create_merkle_internal_value(const mbedtls_lms_parameters_t *params,
167                                         const unsigned char *left_node,
168                                         const unsigned char *right_node,
169                                         unsigned int r_node_idx,
170                                         unsigned char *out)
171 {
172     psa_hash_operation_t op;
173     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
174     size_t output_hash_len;
175     unsigned char r_node_idx_bytes[4];
176 
177     op = psa_hash_operation_init();
178     status = psa_hash_setup(&op, PSA_ALG_SHA_256);
179     if (status != PSA_SUCCESS) {
180         goto exit;
181     }
182 
183     status = psa_hash_update(&op, params->I_key_identifier,
184                              MBEDTLS_LMOTS_I_KEY_ID_LEN);
185     if (status != PSA_SUCCESS) {
186         goto exit;
187     }
188 
189     mbedtls_lms_unsigned_int_to_network_bytes(r_node_idx, 4, r_node_idx_bytes);
190     status = psa_hash_update(&op, r_node_idx_bytes, 4);
191     if (status != PSA_SUCCESS) {
192         goto exit;
193     }
194 
195     status = psa_hash_update(&op, D_INTR_CONSTANT_BYTES, D_CONST_LEN);
196     if (status != PSA_SUCCESS) {
197         goto exit;
198     }
199 
200     status = psa_hash_update(&op, left_node,
201                              MBEDTLS_LMS_M_NODE_BYTES(params->type));
202     if (status != PSA_SUCCESS) {
203         goto exit;
204     }
205 
206     status = psa_hash_update(&op, right_node,
207                              MBEDTLS_LMS_M_NODE_BYTES(params->type));
208     if (status != PSA_SUCCESS) {
209         goto exit;
210     }
211 
212     status = psa_hash_finish(&op, out, MBEDTLS_LMS_M_NODE_BYTES(params->type),
213                              &output_hash_len);
214     if (status != PSA_SUCCESS) {
215         goto exit;
216     }
217 
218 exit:
219     psa_hash_abort(&op);
220 
221     return PSA_TO_MBEDTLS_ERR(status);
222 }
223 
mbedtls_lms_public_init(mbedtls_lms_public_t * ctx)224 void mbedtls_lms_public_init(mbedtls_lms_public_t *ctx)
225 {
226     memset(ctx, 0, sizeof(*ctx));
227 }
228 
mbedtls_lms_public_free(mbedtls_lms_public_t * ctx)229 void mbedtls_lms_public_free(mbedtls_lms_public_t *ctx)
230 {
231     mbedtls_platform_zeroize(ctx, sizeof(*ctx));
232 }
233 
mbedtls_lms_import_public_key(mbedtls_lms_public_t * ctx,const unsigned char * key,size_t key_size)234 int mbedtls_lms_import_public_key(mbedtls_lms_public_t *ctx,
235                                   const unsigned char *key, size_t key_size)
236 {
237     mbedtls_lms_algorithm_type_t type;
238     mbedtls_lmots_algorithm_type_t otstype;
239 
240     type = (mbedtls_lms_algorithm_type_t) mbedtls_lms_network_bytes_to_unsigned_int(
241         MBEDTLS_LMS_TYPE_LEN,
242         key +
243         PUBLIC_KEY_TYPE_OFFSET);
244     if (type != MBEDTLS_LMS_SHA256_M32_H10) {
245         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
246     }
247     ctx->params.type = type;
248 
249     if (key_size != MBEDTLS_LMS_PUBLIC_KEY_LEN(ctx->params.type)) {
250         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
251     }
252 
253     otstype = (mbedtls_lmots_algorithm_type_t) mbedtls_lms_network_bytes_to_unsigned_int(
254         MBEDTLS_LMOTS_TYPE_LEN,
255         key +
256         PUBLIC_KEY_OTSTYPE_OFFSET);
257     if (otstype != MBEDTLS_LMOTS_SHA256_N32_W8) {
258         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
259     }
260     ctx->params.otstype = otstype;
261 
262     memcpy(ctx->params.I_key_identifier,
263            key + PUBLIC_KEY_I_KEY_ID_OFFSET,
264            MBEDTLS_LMOTS_I_KEY_ID_LEN);
265     memcpy(ctx->T_1_pub_key, key + PUBLIC_KEY_ROOT_NODE_OFFSET,
266            MBEDTLS_LMS_M_NODE_BYTES(ctx->params.type));
267 
268     ctx->have_public_key = 1;
269 
270     return 0;
271 }
272 
mbedtls_lms_export_public_key(const mbedtls_lms_public_t * ctx,unsigned char * key,size_t key_size,size_t * key_len)273 int mbedtls_lms_export_public_key(const mbedtls_lms_public_t *ctx,
274                                   unsigned char *key,
275                                   size_t key_size, size_t *key_len)
276 {
277     if (key_size < MBEDTLS_LMS_PUBLIC_KEY_LEN(ctx->params.type)) {
278         return MBEDTLS_ERR_LMS_BUFFER_TOO_SMALL;
279     }
280 
281     if (!ctx->have_public_key) {
282         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
283     }
284 
285     mbedtls_lms_unsigned_int_to_network_bytes(
286         ctx->params.type,
287         MBEDTLS_LMS_TYPE_LEN, key + PUBLIC_KEY_TYPE_OFFSET);
288     mbedtls_lms_unsigned_int_to_network_bytes(ctx->params.otstype,
289                                               MBEDTLS_LMOTS_TYPE_LEN,
290                                               key + PUBLIC_KEY_OTSTYPE_OFFSET);
291     memcpy(key + PUBLIC_KEY_I_KEY_ID_OFFSET,
292            ctx->params.I_key_identifier,
293            MBEDTLS_LMOTS_I_KEY_ID_LEN);
294     memcpy(key +PUBLIC_KEY_ROOT_NODE_OFFSET,
295            ctx->T_1_pub_key,
296            MBEDTLS_LMS_M_NODE_BYTES(ctx->params.type));
297 
298     if (key_len != NULL) {
299         *key_len = MBEDTLS_LMS_PUBLIC_KEY_LEN(ctx->params.type);
300     }
301 
302     return 0;
303 }
304 
mbedtls_lms_verify(const mbedtls_lms_public_t * ctx,const unsigned char * msg,size_t msg_size,const unsigned char * sig,size_t sig_size)305 int mbedtls_lms_verify(const mbedtls_lms_public_t *ctx,
306                        const unsigned char *msg, size_t msg_size,
307                        const unsigned char *sig, size_t sig_size)
308 {
309     unsigned int q_leaf_identifier;
310     unsigned char Kc_candidate_ots_pub_key[MBEDTLS_LMOTS_N_HASH_LEN_MAX];
311     unsigned char Tc_candidate_root_node[MBEDTLS_LMS_M_NODE_BYTES_MAX];
312     unsigned int height;
313     unsigned int curr_node_id;
314     unsigned int parent_node_id;
315     const unsigned char *left_node;
316     const unsigned char *right_node;
317     mbedtls_lmots_parameters_t ots_params;
318     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
319 
320     if (!ctx->have_public_key) {
321         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
322     }
323 
324     if (ctx->params.type
325         != MBEDTLS_LMS_SHA256_M32_H10) {
326         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
327     }
328 
329     if (ctx->params.otstype
330         != MBEDTLS_LMOTS_SHA256_N32_W8) {
331         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
332     }
333 
334     if (sig_size != MBEDTLS_LMS_SIG_LEN(ctx->params.type, ctx->params.otstype)) {
335         return MBEDTLS_ERR_LMS_VERIFY_FAILED;
336     }
337 
338     if (sig_size < SIG_OTS_SIG_OFFSET + MBEDTLS_LMOTS_TYPE_LEN) {
339         return MBEDTLS_ERR_LMS_VERIFY_FAILED;
340     }
341 
342     if (mbedtls_lms_network_bytes_to_unsigned_int(MBEDTLS_LMOTS_TYPE_LEN,
343                                                   sig + SIG_OTS_SIG_OFFSET +
344                                                   MBEDTLS_LMOTS_SIG_TYPE_OFFSET)
345         != MBEDTLS_LMOTS_SHA256_N32_W8) {
346         return MBEDTLS_ERR_LMS_VERIFY_FAILED;
347     }
348 
349     if (sig_size < SIG_TYPE_OFFSET(ctx->params.otstype) + MBEDTLS_LMS_TYPE_LEN) {
350         return MBEDTLS_ERR_LMS_VERIFY_FAILED;
351     }
352 
353     if (mbedtls_lms_network_bytes_to_unsigned_int(MBEDTLS_LMS_TYPE_LEN,
354                                                   sig + SIG_TYPE_OFFSET(ctx->params.otstype))
355         != MBEDTLS_LMS_SHA256_M32_H10) {
356         return MBEDTLS_ERR_LMS_VERIFY_FAILED;
357     }
358 
359 
360     q_leaf_identifier = mbedtls_lms_network_bytes_to_unsigned_int(
361         MBEDTLS_LMOTS_Q_LEAF_ID_LEN, sig + SIG_Q_LEAF_ID_OFFSET);
362 
363     if (q_leaf_identifier >= MERKLE_TREE_LEAF_NODE_AM(ctx->params.type)) {
364         return MBEDTLS_ERR_LMS_VERIFY_FAILED;
365     }
366 
367     memcpy(ots_params.I_key_identifier,
368            ctx->params.I_key_identifier,
369            MBEDTLS_LMOTS_I_KEY_ID_LEN);
370     mbedtls_lms_unsigned_int_to_network_bytes(q_leaf_identifier,
371                                               MBEDTLS_LMOTS_Q_LEAF_ID_LEN,
372                                               ots_params.q_leaf_identifier);
373     ots_params.type = ctx->params.otstype;
374 
375     ret = mbedtls_lmots_calculate_public_key_candidate(&ots_params,
376                                                        msg,
377                                                        msg_size,
378                                                        sig + SIG_OTS_SIG_OFFSET,
379                                                        MBEDTLS_LMOTS_SIG_LEN(ctx->params.otstype),
380                                                        Kc_candidate_ots_pub_key,
381                                                        sizeof(Kc_candidate_ots_pub_key),
382                                                        NULL);
383     if (ret != 0) {
384         return MBEDTLS_ERR_LMS_VERIFY_FAILED;
385     }
386 
387     create_merkle_leaf_value(
388         &ctx->params,
389         Kc_candidate_ots_pub_key,
390         MERKLE_TREE_INTERNAL_NODE_AM(ctx->params.type) + q_leaf_identifier,
391         Tc_candidate_root_node);
392 
393     curr_node_id = MERKLE_TREE_INTERNAL_NODE_AM(ctx->params.type) +
394                    q_leaf_identifier;
395 
396     for (height = 0; height < MBEDTLS_LMS_H_TREE_HEIGHT(ctx->params.type);
397          height++) {
398         parent_node_id = curr_node_id / 2;
399 
400         /* Left/right node ordering matters for the hash */
401         if (curr_node_id & 1) {
402             left_node = sig + SIG_PATH_OFFSET(ctx->params.otstype) +
403                         height * MBEDTLS_LMS_M_NODE_BYTES(ctx->params.type);
404             right_node = Tc_candidate_root_node;
405         } else {
406             left_node = Tc_candidate_root_node;
407             right_node = sig + SIG_PATH_OFFSET(ctx->params.otstype) +
408                          height * MBEDTLS_LMS_M_NODE_BYTES(ctx->params.type);
409         }
410 
411         create_merkle_internal_value(&ctx->params, left_node, right_node,
412                                      parent_node_id, Tc_candidate_root_node);
413 
414         curr_node_id /= 2;
415     }
416 
417     if (memcmp(Tc_candidate_root_node, ctx->T_1_pub_key,
418                MBEDTLS_LMS_M_NODE_BYTES(ctx->params.type))) {
419         return MBEDTLS_ERR_LMS_VERIFY_FAILED;
420     }
421 
422     return 0;
423 }
424 
425 #if defined(MBEDTLS_LMS_PRIVATE)
426 
427 /* Calculate a full Merkle tree based on a private key. This function
428  * implements RFC8554 section 5.3, and is used to generate a public key (as the
429  * public key is the root node of the Merkle tree).
430  *
431  *  ctx                 The LMS private context, containing a parameter
432  *                      set and private key material consisting of both
433  *                      public and private OTS.
434  *
435  *  tree                The output tree, which is 2^(H + 1) hash outputs.
436  *                      In the case of H=10 we have 2048 tree nodes (of
437  *                      which 1024 of them are leaf nodes). Note that
438  *                      because the Merkle tree root is 1-indexed, the 0
439  *                      index tree node is never used.
440  */
calculate_merkle_tree(const mbedtls_lms_private_t * ctx,unsigned char * tree)441 static int calculate_merkle_tree(const mbedtls_lms_private_t *ctx,
442                                  unsigned char *tree)
443 {
444     unsigned int priv_key_idx;
445     unsigned int r_node_idx;
446     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
447 
448     /* First create the leaf nodes, in ascending order */
449     for (priv_key_idx = 0;
450          priv_key_idx < MERKLE_TREE_INTERNAL_NODE_AM(ctx->params.type);
451          priv_key_idx++) {
452         r_node_idx = MERKLE_TREE_INTERNAL_NODE_AM(ctx->params.type) + priv_key_idx;
453 
454         ret = create_merkle_leaf_value(&ctx->params,
455                                        ctx->ots_public_keys[priv_key_idx].public_key,
456                                        r_node_idx,
457                                        &tree[r_node_idx * MBEDTLS_LMS_M_NODE_BYTES(
458                                                  ctx->params.type)]);
459         if (ret != 0) {
460             return ret;
461         }
462     }
463 
464     /* Then the internal nodes, in reverse order so that we can guarantee the
465      * parent has been created */
466     for (r_node_idx = MERKLE_TREE_INTERNAL_NODE_AM(ctx->params.type) - 1;
467          r_node_idx > 0;
468          r_node_idx--) {
469         ret = create_merkle_internal_value(&ctx->params,
470                                            &tree[(r_node_idx * 2) *
471                                                  MBEDTLS_LMS_M_NODE_BYTES(ctx->params.type)],
472                                            &tree[(r_node_idx * 2 + 1) *
473                                                  MBEDTLS_LMS_M_NODE_BYTES(ctx->params.type)],
474                                            r_node_idx,
475                                            &tree[r_node_idx *
476                                                  MBEDTLS_LMS_M_NODE_BYTES(ctx->params.type)]);
477         if (ret != 0) {
478             return ret;
479         }
480     }
481 
482     return 0;
483 }
484 
485 /* Calculate a path from a leaf node of the Merkle tree to the root of the tree,
486  * and return the full path. This function implements RFC8554 section 5.4.1, as
487  * the Merkle path is the main component of an LMS signature.
488  *
489  *  ctx                 The LMS private context, containing a parameter
490  *                      set and private key material consisting of both
491  *                      public and private OTS.
492  *
493  *  leaf_node_id        Which leaf node to calculate the path from.
494  *
495  *  path                The output path, which is H hash outputs.
496  */
get_merkle_path(mbedtls_lms_private_t * ctx,unsigned int leaf_node_id,unsigned char * path)497 static int get_merkle_path(mbedtls_lms_private_t *ctx,
498                            unsigned int leaf_node_id,
499                            unsigned char *path)
500 {
501     const size_t node_bytes = MBEDTLS_LMS_M_NODE_BYTES(ctx->params.type);
502     unsigned int curr_node_id = leaf_node_id;
503     unsigned int adjacent_node_id;
504     unsigned char *tree = NULL;
505     unsigned int height;
506     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
507 
508     tree = mbedtls_calloc(MERKLE_TREE_NODE_AM(ctx->params.type),
509                           node_bytes);
510     if (tree == NULL) {
511         return MBEDTLS_ERR_LMS_ALLOC_FAILED;
512     }
513 
514     ret = calculate_merkle_tree(ctx, tree);
515     if (ret != 0) {
516         goto exit;
517     }
518 
519     for (height = 0; height < MBEDTLS_LMS_H_TREE_HEIGHT(ctx->params.type);
520          height++) {
521         adjacent_node_id = curr_node_id ^ 1;
522 
523         memcpy(&path[height * node_bytes],
524                &tree[adjacent_node_id * node_bytes], node_bytes);
525 
526         curr_node_id >>= 1;
527     }
528 
529     ret = 0;
530 
531 exit:
532     mbedtls_zeroize_and_free(tree, node_bytes *
533                              MERKLE_TREE_NODE_AM(ctx->params.type));
534 
535     return ret;
536 }
537 
mbedtls_lms_private_init(mbedtls_lms_private_t * ctx)538 void mbedtls_lms_private_init(mbedtls_lms_private_t *ctx)
539 {
540     memset(ctx, 0, sizeof(*ctx));
541 }
542 
mbedtls_lms_private_free(mbedtls_lms_private_t * ctx)543 void mbedtls_lms_private_free(mbedtls_lms_private_t *ctx)
544 {
545     unsigned int idx;
546 
547     if (ctx->have_private_key) {
548         if (ctx->ots_private_keys != NULL) {
549             for (idx = 0; idx < MERKLE_TREE_LEAF_NODE_AM(ctx->params.type); idx++) {
550                 mbedtls_lmots_private_free(&ctx->ots_private_keys[idx]);
551             }
552         }
553 
554         if (ctx->ots_public_keys != NULL) {
555             for (idx = 0; idx < MERKLE_TREE_LEAF_NODE_AM(ctx->params.type); idx++) {
556                 mbedtls_lmots_public_free(&ctx->ots_public_keys[idx]);
557             }
558         }
559 
560         mbedtls_free(ctx->ots_private_keys);
561         mbedtls_free(ctx->ots_public_keys);
562     }
563 
564     mbedtls_platform_zeroize(ctx, sizeof(*ctx));
565 }
566 
567 
mbedtls_lms_generate_private_key(mbedtls_lms_private_t * ctx,mbedtls_lms_algorithm_type_t type,mbedtls_lmots_algorithm_type_t otstype,int (* f_rng)(void *,unsigned char *,size_t),void * p_rng,const unsigned char * seed,size_t seed_size)568 int mbedtls_lms_generate_private_key(mbedtls_lms_private_t *ctx,
569                                      mbedtls_lms_algorithm_type_t type,
570                                      mbedtls_lmots_algorithm_type_t otstype,
571                                      int (*f_rng)(void *, unsigned char *, size_t),
572                                      void *p_rng, const unsigned char *seed,
573                                      size_t seed_size)
574 {
575     unsigned int idx = 0;
576     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
577 
578     if (type != MBEDTLS_LMS_SHA256_M32_H10) {
579         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
580     }
581 
582     if (otstype != MBEDTLS_LMOTS_SHA256_N32_W8) {
583         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
584     }
585 
586     if (ctx->have_private_key) {
587         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
588     }
589 
590     ctx->params.type = type;
591     ctx->params.otstype = otstype;
592     ctx->have_private_key = 1;
593 
594     ret = f_rng(p_rng,
595                 ctx->params.I_key_identifier,
596                 MBEDTLS_LMOTS_I_KEY_ID_LEN);
597     if (ret != 0) {
598         goto exit;
599     }
600 
601     /* Requires a cast to size_t to avoid an implicit cast warning on certain
602      * platforms (particularly Windows) */
603     ctx->ots_private_keys = mbedtls_calloc((size_t) MERKLE_TREE_LEAF_NODE_AM(ctx->params.type),
604                                            sizeof(*ctx->ots_private_keys));
605     if (ctx->ots_private_keys == NULL) {
606         ret = MBEDTLS_ERR_LMS_ALLOC_FAILED;
607         goto exit;
608     }
609 
610     /* Requires a cast to size_t to avoid an implicit cast warning on certain
611      * platforms (particularly Windows) */
612     ctx->ots_public_keys = mbedtls_calloc((size_t) MERKLE_TREE_LEAF_NODE_AM(ctx->params.type),
613                                           sizeof(*ctx->ots_public_keys));
614     if (ctx->ots_public_keys == NULL) {
615         ret = MBEDTLS_ERR_LMS_ALLOC_FAILED;
616         goto exit;
617     }
618 
619     for (idx = 0; idx < MERKLE_TREE_LEAF_NODE_AM(ctx->params.type); idx++) {
620         mbedtls_lmots_private_init(&ctx->ots_private_keys[idx]);
621         mbedtls_lmots_public_init(&ctx->ots_public_keys[idx]);
622     }
623 
624 
625     for (idx = 0; idx < MERKLE_TREE_LEAF_NODE_AM(ctx->params.type); idx++) {
626         ret = mbedtls_lmots_generate_private_key(&ctx->ots_private_keys[idx],
627                                                  otstype,
628                                                  ctx->params.I_key_identifier,
629                                                  idx, seed, seed_size);
630         if (ret != 0) {
631             goto exit;
632         }
633 
634         ret = mbedtls_lmots_calculate_public_key(&ctx->ots_public_keys[idx],
635                                                  &ctx->ots_private_keys[idx]);
636         if (ret != 0) {
637             goto exit;
638         }
639     }
640 
641     ctx->q_next_usable_key = 0;
642 
643 exit:
644     if (ret != 0) {
645         mbedtls_lms_private_free(ctx);
646     }
647 
648     return ret;
649 }
650 
mbedtls_lms_calculate_public_key(mbedtls_lms_public_t * ctx,const mbedtls_lms_private_t * priv_ctx)651 int mbedtls_lms_calculate_public_key(mbedtls_lms_public_t *ctx,
652                                      const mbedtls_lms_private_t *priv_ctx)
653 {
654     const size_t node_bytes = MBEDTLS_LMS_M_NODE_BYTES(priv_ctx->params.type);
655     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
656     unsigned char *tree = NULL;
657 
658     if (!priv_ctx->have_private_key) {
659         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
660     }
661 
662     if (priv_ctx->params.type
663         != MBEDTLS_LMS_SHA256_M32_H10) {
664         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
665     }
666 
667     if (priv_ctx->params.otstype
668         != MBEDTLS_LMOTS_SHA256_N32_W8) {
669         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
670     }
671 
672     tree = mbedtls_calloc(MERKLE_TREE_NODE_AM(priv_ctx->params.type),
673                           node_bytes);
674     if (tree == NULL) {
675         return MBEDTLS_ERR_LMS_ALLOC_FAILED;
676     }
677 
678     memcpy(&ctx->params, &priv_ctx->params,
679            sizeof(mbedtls_lmots_parameters_t));
680 
681     ret = calculate_merkle_tree(priv_ctx, tree);
682     if (ret != 0) {
683         goto exit;
684     }
685 
686     /* Root node is always at position 1, due to 1-based indexing */
687     memcpy(ctx->T_1_pub_key, &tree[node_bytes], node_bytes);
688 
689     ctx->have_public_key = 1;
690 
691     ret = 0;
692 
693 exit:
694     mbedtls_zeroize_and_free(tree, node_bytes *
695                              MERKLE_TREE_NODE_AM(priv_ctx->params.type));
696 
697     return ret;
698 }
699 
700 
mbedtls_lms_sign(mbedtls_lms_private_t * ctx,int (* f_rng)(void *,unsigned char *,size_t),void * p_rng,const unsigned char * msg,unsigned int msg_size,unsigned char * sig,size_t sig_size,size_t * sig_len)701 int mbedtls_lms_sign(mbedtls_lms_private_t *ctx,
702                      int (*f_rng)(void *, unsigned char *, size_t),
703                      void *p_rng, const unsigned char *msg,
704                      unsigned int msg_size, unsigned char *sig, size_t sig_size,
705                      size_t *sig_len)
706 {
707     uint32_t q_leaf_identifier;
708     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
709 
710     if (!ctx->have_private_key) {
711         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
712     }
713 
714     if (sig_size < MBEDTLS_LMS_SIG_LEN(ctx->params.type, ctx->params.otstype)) {
715         return MBEDTLS_ERR_LMS_BUFFER_TOO_SMALL;
716     }
717 
718     if (ctx->params.type != MBEDTLS_LMS_SHA256_M32_H10) {
719         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
720     }
721 
722     if (ctx->params.otstype
723         != MBEDTLS_LMOTS_SHA256_N32_W8) {
724         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
725     }
726 
727     if (ctx->q_next_usable_key >= MERKLE_TREE_LEAF_NODE_AM(ctx->params.type)) {
728         return MBEDTLS_ERR_LMS_OUT_OF_PRIVATE_KEYS;
729     }
730 
731 
732     q_leaf_identifier = ctx->q_next_usable_key;
733     /* This new value must _always_ be written back to the disk before the
734      * signature is returned.
735      */
736     ctx->q_next_usable_key += 1;
737 
738     if (MBEDTLS_LMS_SIG_LEN(ctx->params.type, ctx->params.otstype)
739         < SIG_OTS_SIG_OFFSET) {
740         return MBEDTLS_ERR_LMS_BAD_INPUT_DATA;
741     }
742 
743     ret = mbedtls_lmots_sign(&ctx->ots_private_keys[q_leaf_identifier],
744                              f_rng,
745                              p_rng,
746                              msg,
747                              msg_size,
748                              sig + SIG_OTS_SIG_OFFSET,
749                              MBEDTLS_LMS_SIG_LEN(ctx->params.type,
750                                                  ctx->params.otstype) - SIG_OTS_SIG_OFFSET,
751                              NULL);
752     if (ret != 0) {
753         return ret;
754     }
755 
756     mbedtls_lms_unsigned_int_to_network_bytes(ctx->params.type,
757                                               MBEDTLS_LMS_TYPE_LEN,
758                                               sig + SIG_TYPE_OFFSET(ctx->params.otstype));
759     mbedtls_lms_unsigned_int_to_network_bytes(q_leaf_identifier,
760                                               MBEDTLS_LMOTS_Q_LEAF_ID_LEN,
761                                               sig + SIG_Q_LEAF_ID_OFFSET);
762 
763     ret = get_merkle_path(ctx,
764                           MERKLE_TREE_INTERNAL_NODE_AM(ctx->params.type) + q_leaf_identifier,
765                           sig + SIG_PATH_OFFSET(ctx->params.otstype));
766     if (ret != 0) {
767         return ret;
768     }
769 
770     if (sig_len != NULL) {
771         *sig_len = MBEDTLS_LMS_SIG_LEN(ctx->params.type, ctx->params.otstype);
772     }
773 
774 
775     return 0;
776 }
777 
778 #endif /* defined(MBEDTLS_LMS_PRIVATE) */
779 #endif /* defined(MBEDTLS_LMS_C) */
780