1"""Common features for bignum in test generation framework."""
2# Copyright The Mbed TLS Contributors
3# SPDX-License-Identifier: Apache-2.0
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may
6# not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17from abc import abstractmethod
18from typing import Iterator, List, Tuple, TypeVar, Any
19from itertools import chain
20
21from . import test_case
22from . import test_data_generation
23from .bignum_data import INPUTS_DEFAULT, MODULI_DEFAULT
24
# Generic type variable used in the annotations of helpers in this module,
# such as combination_pairs().
T = TypeVar('T') #pylint: disable=invalid-name
26
def invmod(a: int, n: int) -> int:
    """Compute a multiplicative inverse of a modulo n.

    This runs the extended Euclidean algorithm in the same way as
    long_invmod() in CPython. The returned value is congruent modulo n to
    pow(a, -1, n) (available in Python 3.8+), but it is not reduced: it can
    be negative (for instance invmod(3, 7) is -2), so callers that need a
    canonical representative should reduce the result modulo n themselves.

    Raises:
        ValueError: if the gcd computed for the inputs is not 1.
    """
    coeff, next_coeff = 1, 0
    while n != 0:
        quotient, remainder = divmod(a, n)
        coeff, next_coeff = next_coeff, coeff - quotient * next_coeff
        a, n = n, remainder
    # The loop leaves the gcd of the original inputs in a.
    if a != 1:
        raise ValueError("Not invertible")
    return coeff
41
def hex_to_int(val: str) -> int:
    """Convert a string using the syntax accepted by mbedtls_test_read_mpi().

    The empty string and a lone "-" both denote zero; any other input is
    parsed by int() as a base-16 number. This is a superset of what is
    accepted by mbedtls_test_read_mpi_core().
    """
    return 0 if val in ('', '-') else int(val, 16)
50
def quote_str(val) -> str:
    """Return val formatted as text and enclosed in double quotes."""
    return '"{}"'.format(val)
53
def bound_mpi(val: int, bits_in_limb: int) -> int:
    """First number exceeding what fits in the limbs needed to store val.

    The limb count for val is obtained from limbs_mpi() and the bound for
    that many limbs from bound_mpi_limbs().
    """
    limb_count = limbs_mpi(val, bits_in_limb)
    return bound_mpi_limbs(limb_count, bits_in_limb)
57
def bound_mpi_limbs(limbs: int, bits_in_limb: int) -> int:
    """First number exceeding the maximum value of the given number of limbs.

    This is 2 raised to the total number of bits in those limbs.
    """
    return 1 << (limbs * bits_in_limb)
62
def limbs_mpi(val: int, bits_in_limb: int) -> int:
    """Return the number of limbs required to store value.

    The bit length of val is divided by the limb size, rounding up; zero
    therefore needs no limb at all.
    """
    bit_count = val.bit_length()
    return (bit_count + bits_in_limb - 1) // bits_in_limb
66
def combination_pairs(values: List[T]) -> List[Tuple[T, T]]:
    """Return every ordered pair that can be formed from the input values.

    The result is the full product of values with itself: each value is
    paired with itself, and both (x, y) and (y, x) are present. Pairs are
    ordered by their first element, then by their second.
    """
    pairs = [] # type: List[Tuple[T, T]]
    for first in values:
        pairs.extend((first, second) for second in values)
    return pairs
70
class OperationCommon(test_data_generation.BaseTest):
    """Common features for bignum binary operations.

    This adds functionality common in binary operation tests.

    Attributes:
        symbol: Symbol to use for the operation in case description.
        input_values: List of values to use as test case inputs. These are
            combined to produce pairs of values.
        input_cases: List of tuples containing pairs of test case inputs. This
            can be used to implement specific pairs of inputs.
        unique_combinations_only: Boolean to select if test case combinations
            must be unique. If True, only A,B or B,A would be included as a test
            case. If False, both A,B and B,A would be included.
        input_styles: The values accepted for input_style.
        input_style: Controls the way how test data is passed to the functions
            in the generated test cases. "variable" passes them as they are
            defined in the python source. "fixed" and "arch_split" pad the
            values with zeroes up to a whole number of limbs. With
            "arch_split", test cases are additionally generated for every
            entry of limb_sizes, each one depending on the matching
            MBEDTLS_HAVE_INTxx option.
        limb_sizes: Limb sizes, in bits, accepted for bits_in_limb.
        arities: The values accepted for arity.
        arity: the number of operands for the operation. Currently supported
            values are 1 and 2.
    """
    symbol = ""
    input_values = INPUTS_DEFAULT # type: List[str]
    input_cases = [] # type: List[Any]
    unique_combinations_only = False
    input_styles = ["variable", "fixed", "arch_split"] # type: List[str]
    input_style = "variable" # type: str
    limb_sizes = [32, 64] # type: List[int]
    arities = [1, 2]
    arity = 2

    def __init__(self, val_a: str, val_b: str = "0", bits_in_limb: int = 32) -> None:
        """Store the operands, both as given and converted to integers.

        Raises:
            ValueError: if an operand cannot be parsed by hex_to_int(), or if
                bits_in_limb is not one of limb_sizes.
        """
        self.val_a = val_a
        self.val_b = val_b
        # Setting the int versions here as opposed to making them @properties
        # provides earlier/more robust input validation.
        self.int_a = hex_to_int(val_a)
        self.int_b = hex_to_int(val_b)
        if bits_in_limb not in self.limb_sizes:
            raise ValueError("Invalid number of bits in limb!")
        if self.input_style == "arch_split":
            # The padded data is only meaningful for builds using this limb
            # size, so tie the test case to it.
            self.dependencies = ["MBEDTLS_HAVE_INT{:d}".format(bits_in_limb)]
        self.bits_in_limb = bits_in_limb

    @property
    def boundary(self) -> int:
        """The value that determines how many limbs the operands need.

        This is the only operand for unary operations and the larger of the
        two operands for binary ones.
        """
        if self.arity == 1:
            return self.int_a
        elif self.arity == 2:
            return max(self.int_a, self.int_b)
        raise ValueError("Unsupported number of operands!")

    @property
    def limb_boundary(self) -> int:
        """First number that does not fit in the limbs needed for boundary."""
        return bound_mpi(self.boundary, self.bits_in_limb)

    @property
    def limbs(self) -> int:
        """Number of limbs needed to store boundary."""
        return limbs_mpi(self.boundary, self.bits_in_limb)

    @property
    def hex_digits(self) -> int:
        """Number of hexadecimal digits needed to write out all the limbs."""
        return 2 * (self.limbs * self.bits_in_limb // 8)

    def format_arg(self, val) -> str:
        """Format a hexadecimal string according to input_style.

        The string is returned unchanged for the "variable" style and is
        padded with leading zeroes to hex_digits characters otherwise.

        Raises:
            ValueError: if input_style is not one of input_styles.
        """
        if self.input_style not in self.input_styles:
            raise ValueError("Unknown input style!")
        if self.input_style == "variable":
            return val
        else:
            return val.zfill(self.hex_digits)

    def format_result(self, res) -> str:
        """Format an integer result as a quoted hexadecimal argument."""
        res_str = '{:x}'.format(res)
        return quote_str(self.format_arg(res_str))

    @property
    def arg_a(self) -> str:
        """The first operand, formatted according to input_style."""
        return self.format_arg(self.val_a)

    @property
    def arg_b(self) -> str:
        """The second operand, formatted according to input_style.

        Raises:
            AttributeError: if the operation is unary.
        """
        if self.arity == 1:
            raise AttributeError("Operation is unary and doesn't have arg_b!")
        return self.format_arg(self.val_b)

    def arguments(self) -> List[str]:
        """Return the quoted operands followed by the expected result(s)."""
        args = [quote_str(self.arg_a)]
        if self.arity == 2:
            args.append(quote_str(self.arg_b))
        return args + self.result()

    def description(self) -> str:
        """Generate a description for the test case.

        If not set, case_description uses the form A `symbol` B (or
        `symbol` A for unary operations), where symbol is used to represent
        the operation. The remainder of the description is produced by the
        parent class.
        """
        if not self.case_description:
            if self.arity == 1:
                self.case_description = "{} {:x}".format(
                    self.symbol, self.int_a
                )
            elif self.arity == 2:
                self.case_description = "{:x} {} {:x}".format(
                    self.int_a, self.symbol, self.int_b
                )
        return super().description()

    @property
    def is_valid(self) -> bool:
        """Whether a test case should be generated for these inputs.

        Every input is accepted here; subclasses can override this to filter
        out inputs that make no sense for their operation.
        """
        return True

    @abstractmethod
    def result(self) -> List[str]:
        """Get the result of the operation.

        This could be calculated during initialization and stored as `_result`
        and then returned, or calculated when the method is called.
        """
        raise NotImplementedError

    @classmethod
    def get_value_pairs(cls) -> Iterator[Tuple[str, str]]:
        """Generator to yield pairs of inputs built from input_values.

        Unary operations get each value paired with a placeholder "0".
        Binary operations get every ordered pair of values, unless
        unique_combinations_only is set, in which case each unordered pair
        is produced only once. The specific cases in input_cases are not
        included here.

        Raises:
            ValueError: if arity is neither 1 nor 2.
        """
        if cls.arity == 1:
            yield from ((a, "0") for a in cls.input_values)
        elif cls.arity == 2:
            if cls.unique_combinations_only:
                # Pair each value only with itself and with the values that
                # follow it, so that A,B is produced but B,A is not.
                yield from (
                    (a, b)
                    for index, a in enumerate(cls.input_values)
                    for b in cls.input_values[index:]
                )
            else:
                yield from (
                    (a, b)
                    for a in cls.input_values
                    for b in cls.input_values
                )
        else:
            raise ValueError("Unsupported number of operands!")

    @classmethod
    def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
        """Generate a test case for every valid combination of inputs.

        Test cases built from input_values come first, followed by the ones
        built from input_cases. With the "arch_split" input style, each set
        of inputs is used once per entry of limb_sizes. Inputs for which
        is_valid is false are skipped.

        Raises:
            ValueError: if input_style or arity has an unsupported value.
        """
        if cls.input_style not in cls.input_styles:
            raise ValueError("Unknown input style!")
        if cls.arity not in cls.arities:
            raise ValueError("Unsupported number of operands!")
        if cls.input_style == "arch_split":
            test_objects = (cls(a, b, bits_in_limb=bil)
                            for a, b in cls.get_value_pairs()
                            for bil in cls.limb_sizes)
            special_cases = (cls(*args, bits_in_limb=bil) # type: ignore
                             for args in cls.input_cases
                             for bil in cls.limb_sizes)
        else:
            test_objects = (cls(a, b)
                            for a, b in cls.get_value_pairs())
            special_cases = (cls(*args) for args in cls.input_cases)
        yield from (valid_test_object.create_test_case()
                    for valid_test_object in filter(
                        lambda test_object: test_object.is_valid,
                        chain(test_objects, special_cases)
                        )
                    )
239
240
class ModOperationCommon(OperationCommon):
    #pylint: disable=abstract-method
    """Target for bignum mod_raw test case generation.

    Extends the common operation features with a modulus N: the modulus is
    passed as the first test argument, it determines the number of limbs
    used, and operands that are not smaller than it are rejected.

    Attributes:
        moduli: List of values to use as moduli. Each of them is combined
            with the inputs produced by get_value_pairs().
    """
    moduli = MODULI_DEFAULT # type: List[str]

    def __init__(self, val_n: str, val_a: str, val_b: str = "0",
                 bits_in_limb: int = 64) -> None:
        """Store the modulus along with the operands.

        Note that the default limb size here is 64 bits.
        """
        super().__init__(val_a=val_a, val_b=val_b, bits_in_limb=bits_in_limb)
        self.val_n = val_n
        # Converting eagerly, rather than in a property, means that a
        # malformed modulus is reported as soon as the object is created.
        self.int_n = hex_to_int(val_n)

    def to_montgomery(self, val: int) -> int:
        """Return val multiplied by r, reduced modulo N."""
        scaled = val * self.r
        return scaled % self.int_n

    def from_montgomery(self, val: int) -> int:
        """Return val multiplied by the inverse of r, reduced modulo N."""
        unscaled = val * self.r_inv
        return unscaled % self.int_n

    @property
    def boundary(self) -> int:
        """The modulus, which determines the number of limbs in use."""
        return self.int_n

    @property
    def arg_n(self) -> str:
        """The modulus, formatted according to input_style."""
        return self.format_arg(self.val_n)

    def arguments(self) -> List[str]:
        """Return the quoted modulus followed by the parent's arguments."""
        modulus_arg = quote_str(self.arg_n)
        return [modulus_arg] + super().arguments()

    @property
    def r(self) -> int: # pylint: disable=invalid-name
        """2 raised to the number of bits in the limbs needed to store N."""
        limb_count = limbs_mpi(self.int_n, self.bits_in_limb)
        return bound_mpi_limbs(limb_count, self.bits_in_limb)

    @property
    def r_inv(self) -> int:
        """An inverse of r modulo N, as computed by invmod()."""
        return invmod(self.r, self.int_n)

    @property
    def r2(self) -> int: # pylint: disable=invalid-name
        """The square of r (not reduced modulo N)."""
        return self.r ** 2

    @property
    def is_valid(self) -> bool:
        """Whether every operand in use is smaller than the modulus."""
        if self.int_a >= self.int_n:
            return False
        return not (self.arity == 2 and self.int_b >= self.int_n)

    def description(self) -> str:
        """Generate a description for the test case.

        It uses the form A `symbol` B mod N, where symbol is used to represent
        the operation. The " mod N" suffix is only appended when no case
        description had been set before this call.
        """
        had_no_description = not self.case_description
        text = super().description()
        if had_no_description:
            text += " mod {:x}".format(self.int_n)
        return text

    @classmethod
    def input_cases_args(cls) -> Iterator[Tuple[Any, Any, Any]]:
        """Yield constructor arguments for the entries of input_cases.

        Entries list the operand(s) first and the modulus last; they are
        reordered to put the modulus first, and unary entries get a
        placeholder "0" as second operand.

        Raises:
            ValueError: if arity is neither 1 nor 2.
        """
        if cls.arity == 1:
            for operand, modulus in cls.input_cases:
                yield (modulus, operand, "0")
        elif cls.arity == 2:
            for left, right, modulus in cls.input_cases:
                yield (modulus, left, right)
        else:
            raise ValueError("Unsupported number of operands!")

    @classmethod
    def generate_function_tests(cls) -> Iterator[test_case.TestCase]:
        """Generate a test case for every valid combination of inputs.

        Each modulus is combined with every pair from get_value_pairs(); the
        specific cases from input_cases follow. With the "arch_split" input
        style, each set of inputs is used once per entry of limb_sizes.
        Inputs for which is_valid is false are skipped.

        Raises:
            ValueError: if input_style or arity has an unsupported value.
        """
        if cls.input_style not in cls.input_styles:
            raise ValueError("Unknown input style!")
        if cls.arity not in cls.arities:
            raise ValueError("Unsupported number of operands!")
        if cls.input_style == "arch_split":
            generated = (cls(modulus, left, right, bits_in_limb=limb_bits)
                         for modulus in cls.moduli
                         for left, right in cls.get_value_pairs()
                         for limb_bits in cls.limb_sizes)
            special = (cls(*case_args, bits_in_limb=limb_bits)
                       for case_args in cls.input_cases_args()
                       for limb_bits in cls.limb_sizes)
        else:
            generated = (cls(modulus, left, right)
                         for modulus in cls.moduli
                         for left, right in cls.get_value_pairs())
            special = (cls(*case_args)
                       for case_args in cls.input_cases_args())
        for candidate in chain(generated, special):
            if candidate.is_valid:
                yield candidate.create_test_case()
336
337# BEGIN MERGE SLOT 1
338
339# END MERGE SLOT 1
340
341# BEGIN MERGE SLOT 2
342
343# END MERGE SLOT 2
344
345# BEGIN MERGE SLOT 3
346
347# END MERGE SLOT 3
348
349# BEGIN MERGE SLOT 4
350
351# END MERGE SLOT 4
352
353# BEGIN MERGE SLOT 5
354
355# END MERGE SLOT 5
356
357# BEGIN MERGE SLOT 6
358
359# END MERGE SLOT 6
360
361# BEGIN MERGE SLOT 7
362
363# END MERGE SLOT 7
364
365# BEGIN MERGE SLOT 8
366
367# END MERGE SLOT 8
368
369# BEGIN MERGE SLOT 9
370
371# END MERGE SLOT 9
372
373# BEGIN MERGE SLOT 10
374
375# END MERGE SLOT 10
376