1#!/usr/bin/env python3
2
3# Copyright 2023 Google LLC
4# SPDX-License-Identifier: Apache-2.0
5"""
6Tests for check_init_priorities
7"""
8
9import pathlib
10import unittest
11from unittest import mock
12
13import check_init_priorities
14from elftools.elf.relocation import Section
15from elftools.elf.sections import SymbolTableSection
16
17
class TestPriority(unittest.TestCase):
    """Tests for the Priority class."""

    def test_priority_parsing(self):
        """Level names map to (level index, priority); unknown levels raise."""
        prio1 = check_init_priorities.Priority("POST_KERNEL", 12)
        self.assertEqual(prio1._level_priority, (3, 12))

        prio1 = check_init_priorities.Priority("APPLICATION", 9999)
        self.assertEqual(prio1._level_priority, (4, 9999))

        # Each invalid level needs its own assertRaises context: inside a single
        # context the first raising call exits the block, so any later call in
        # that block would never be executed (and never be checked).
        for bad_level in ("i-am-not-a-priority", "_DOESNOTEXIST0_"):
            with self.subTest(level=bad_level):
                with self.assertRaises(ValueError):
                    check_init_priorities.Priority(bad_level, 0)

    def test_priority_levels(self):
        """Priorities listed in expected init order are already sorted."""
        prios = [
            check_init_priorities.Priority("EARLY", 0),
            check_init_priorities.Priority("EARLY", 1),
            check_init_priorities.Priority("PRE_KERNEL_1", 0),
            check_init_priorities.Priority("PRE_KERNEL_1", 1),
            check_init_priorities.Priority("PRE_KERNEL_2", 0),
            check_init_priorities.Priority("PRE_KERNEL_2", 1),
            check_init_priorities.Priority("POST_KERNEL", 0),
            check_init_priorities.Priority("POST_KERNEL", 1),
            check_init_priorities.Priority("APPLICATION", 0),
            check_init_priorities.Priority("APPLICATION", 1),
            check_init_priorities.Priority("SMP", 0),
            check_init_priorities.Priority("SMP", 1),
        ]

        self.assertListEqual(prios, sorted(prios))

    def test_priority_strings(self):
        """str() and repr() render the level name and numeric priority."""
        prio = check_init_priorities.Priority("POST_KERNEL", 12)
        self.assertEqual(str(prio), "POST_KERNEL+12")
        self.assertEqual(repr(prio), "<Priority POST_KERNEL 12>")
54
55
class testZephyrInitLevels(unittest.TestCase):
    """Tests for the ZephyrInitLevels class.

    ``ZephyrInitLevels.__init__`` is patched out in every test so no ELF file
    is read; each test assigns only the private attributes that the method
    under test reads.
    """

    @mock.patch("check_init_priorities.ZephyrInitLevels.__init__", return_value=None)
    def test_load_objects(self, mock_zilinit):
        """_load_objects maps symbol address -> (name, size, section index)."""
        mock_elf = mock.Mock()

        # A non-symbol-table section is mixed in; only the SymbolTableSection
        # mock provides iter_symbols, so the test checks it is the one walked.
        sts = mock.Mock(spec=SymbolTableSection)
        rel = mock.Mock(spec=Section)
        mock_elf.iter_sections.return_value = [sts, rel]

        s0 = mock.Mock()
        s0.name = "a"
        s0.entry.st_info.type = "STT_OBJECT"
        s0.entry.st_size = 4
        s0.entry.st_value = 0xAA
        s0.entry.st_shndx = 1

        # A symbol without a name must not show up in _objects.
        s1 = mock.Mock()
        s1.name = None

        # STT_FUNC symbols are collected alongside STT_OBJECT ones.
        s2 = mock.Mock()
        s2.name = "b"
        s2.entry.st_info.type = "STT_FUNC"
        s2.entry.st_size = 8
        s2.entry.st_value = 0xBB
        s2.entry.st_shndx = 2

        sts.iter_symbols.return_value = [s0, s1, s2]

        obj = check_init_priorities.ZephyrInitLevels("", None)
        obj._elf = mock_elf
        obj._load_objects()

        self.assertDictEqual(obj._objects, {0xAA: ("a", 4, 1), 0xBB: ("b", 8, 2)})

    @mock.patch("check_init_priorities.ZephyrInitLevels.__init__", return_value=None)
    def test_load_level_addr(self, mock_zilinit):
        """__init_<LEVEL>_start symbols give level addresses; __init_end the end."""
        mock_elf = mock.Mock()

        sts = mock.Mock(spec=SymbolTableSection)
        rel = mock.Mock(spec=Section)
        mock_elf.iter_sections.return_value = [sts, rel]

        s0 = mock.Mock()
        s0.name = "__init_EARLY_start"
        s0.entry.st_value = 0x00

        s1 = mock.Mock()
        s1.name = "__init_PRE_KERNEL_1_start"
        s1.entry.st_value = 0x11

        s2 = mock.Mock()
        s2.name = "__init_PRE_KERNEL_2_start"
        s2.entry.st_value = 0x22

        s3 = mock.Mock()
        s3.name = "__init_POST_KERNEL_start"
        s3.entry.st_value = 0x33

        s4 = mock.Mock()
        s4.name = "__init_APPLICATION_start"
        s4.entry.st_value = 0x44

        s5 = mock.Mock()
        s5.name = "__init_SMP_start"
        s5.entry.st_value = 0x55

        s6 = mock.Mock()
        s6.name = "__init_end"
        s6.entry.st_value = 0x66

        sts.iter_symbols.return_value = [s0, s1, s2, s3, s4, s5, s6]

        obj = check_init_priorities.ZephyrInitLevels("", None)
        obj._elf = mock_elf
        obj._load_level_addr()

        self.assertDictEqual(
            obj._init_level_addr,
            {
                "EARLY": 0x00,
                "PRE_KERNEL_1": 0x11,
                "PRE_KERNEL_2": 0x22,
                "POST_KERNEL": 0x33,
                "APPLICATION": 0x44,
                "SMP": 0x55,
            },
        )
        self.assertEqual(obj._init_level_end, 0x66)

    @mock.patch("check_init_priorities.ZephyrInitLevels.__init__", return_value=None)
    def test_device_ord_from_name(self, mock_zilinit):
        """Only __device_dts_ord_<N> names yield an ordinal; others give None."""
        obj = check_init_priorities.ZephyrInitLevels("", None)

        self.assertIsNone(obj._device_ord_from_name(None))
        self.assertIsNone(obj._device_ord_from_name("hey, hi!"))
        self.assertEqual(obj._device_ord_from_name("__device_dts_ord_123"), 123)

    @mock.patch("check_init_priorities.ZephyrInitLevels.__init__", return_value=None)
    def test_object_name(self, mock_zilinit):
        """Address 0 is "NULL", unmapped addresses "unknown", else the name."""
        obj = check_init_priorities.ZephyrInitLevels("", None)
        # Same (name, size, section index) shape that _load_objects produces.
        obj._objects = {0x123: ("name", 4, 0)}

        self.assertEqual(obj._object_name(0), "NULL")
        self.assertEqual(obj._object_name(73), "unknown")
        self.assertEqual(obj._object_name(0x123), "name")

    @mock.patch("check_init_priorities.ZephyrInitLevels.__init__", return_value=None)
    def test_initlevel_pointer_32(self, mock_zilinit):
        """On ELFCLASS32, pointer slots are 4 bytes apart."""
        obj = check_init_priorities.ZephyrInitLevels("", None)
        obj._elf = mock.Mock()
        obj._elf.elfclass = 32
        mock_section = mock.Mock()
        obj._elf.get_section.return_value = mock_section
        mock_section.header.sh_addr = 0x100
        # Three little-endian 32-bit words: 1, 2, 3.
        mock_section.data.return_value = b"\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00"

        self.assertEqual(obj._initlevel_pointer(0x100, 0, 0), 1)
        self.assertEqual(obj._initlevel_pointer(0x100, 1, 0), 2)
        self.assertEqual(obj._initlevel_pointer(0x104, 0, 0), 2)
        self.assertEqual(obj._initlevel_pointer(0x104, 1, 0), 3)

    @mock.patch("check_init_priorities.ZephyrInitLevels.__init__", return_value=None)
    def test_initlevel_pointer_64(self, mock_zilinit):
        """On ELFCLASS64, pointer slots are 8 bytes apart."""
        obj = check_init_priorities.ZephyrInitLevels("", None)
        obj._elf = mock.Mock()
        obj._elf.elfclass = 64
        mock_section = mock.Mock()
        obj._elf.get_section.return_value = mock_section
        mock_section.header.sh_addr = 0x100
        # Three little-endian 64-bit words: 1, 2, 3.
        mock_section.data.return_value = (
            b"\x01\x00\x00\x00\x00\x00\x00\x00"
            b"\x02\x00\x00\x00\x00\x00\x00\x00"
            b"\x03\x00\x00\x00\x00\x00\x00\x00"
        )

        self.assertEqual(obj._initlevel_pointer(0x100, 0, 0), 1)
        self.assertEqual(obj._initlevel_pointer(0x100, 1, 0), 2)
        self.assertEqual(obj._initlevel_pointer(0x108, 0, 0), 2)
        self.assertEqual(obj._initlevel_pointer(0x108, 1, 0), 3)

    @mock.patch("check_init_priorities.ZephyrInitLevels._object_name")
    @mock.patch("check_init_priorities.ZephyrInitLevels._initlevel_pointer")
    @mock.patch("check_init_priorities.ZephyrInitLevels.__init__", return_value=None)
    def test_process_initlevels(self, mock_zilinit, mock_ip, mock_on):
        """Init entries are grouped per level; device entries fill obj.devices."""
        obj = check_init_priorities.ZephyrInitLevels("", None)
        # EARLY and PRE_KERNEL_1 are empty (same start as the next level), so
        # entries 0x00 and 0x04 belong to PRE_KERNEL_2 and 0x08 to POST_KERNEL.
        obj._init_level_addr = {
            "EARLY": 0x00,
            "PRE_KERNEL_1": 0x00,
            "PRE_KERNEL_2": 0x00,
            "POST_KERNEL": 0x08,
            "APPLICATION": 0x0C,
            "SMP": 0x0C,
        }
        obj._init_level_end = 0x0C
        obj._objects = {
            0x00: ("a", 4, 0),
            0x04: ("b", 4, 0),
            0x08: ("c", 4, 0),
        }
        obj._object_addr = {
            "__device_dts_ord_11": 0x00,
            "__device_dts_ord_22": 0x04,
        }

        # The stubbed _initlevel_pointer echoes its arguments, so _object_name
        # receives the (address, index, section) tuple and can key on it.
        mock_ip.side_effect = lambda *args: args

        def mock_obj_name(*args):
            if args[0] == (0, 5, 0):
                return "i0"
            elif args[0] == (0, 1, 0):
                return "__device_dts_ord_11"
            elif args[0] == (4, 5, 0):
                return "i1"
            elif args[0] == (4, 1, 0):
                return "__device_dts_ord_22"
            return f"name_{args[0][0]}_{args[0][1]}"

        mock_on.side_effect = mock_obj_name

        obj._process_initlevels()

        self.assertDictEqual(
            obj.initlevels,
            {
                "EARLY": [],
                "PRE_KERNEL_1": [],
                "PRE_KERNEL_2": ["a: i0(__device_dts_ord_11)", "b: i1(__device_dts_ord_22)"],
                "POST_KERNEL": ["c: name_8_0(name_8_1)"],
                "APPLICATION": [],
                "SMP": [],
            },
        )
        self.assertDictEqual(
            obj.devices,
            {
                11: (check_init_priorities.Priority("PRE_KERNEL_2", 0), "i0"),
                22: (check_init_priorities.Priority("PRE_KERNEL_2", 1), "i1"),
            },
        )
257
258
class testValidator(unittest.TestCase):
    """Tests for the Validator class.

    Except for test_initialize, ``Validator.__init__`` is patched out and each
    test assigns the attributes (log, _obj, _ord2node, errors) it relies on.
    """

    @mock.patch("check_init_priorities.ZephyrInitLevels")
    @mock.patch("pickle.load")
    def test_initialize(self, mock_pl, mock_zil):
        """The constructor builds ZephyrInitLevels and opens the pickle as "rb"."""
        mock_log = mock.Mock()
        mock_prio = mock.Mock()
        mock_obj = mock.Mock()
        mock_obj.defined_devices = {123: mock_prio}
        mock_zil.return_value = mock_obj

        with mock.patch("builtins.open", mock.mock_open()) as mock_open:
            validator = check_init_priorities.Validator("path", "pickle", mock_log, None)

        self.assertEqual(validator._obj, mock_obj)
        mock_zil.assert_called_once_with("path", None)
        mock_open.assert_called_once_with(pathlib.Path("pickle"), "rb")

    @mock.patch("check_init_priorities.Validator.__init__", return_value=None)
    def test_check_dep_same_node(self, mock_vinit):
        """A node depending on itself produces no log output."""
        validator = check_init_priorities.Validator("", "", None, None)
        validator.log = mock.Mock()

        validator._check_dep(123, 123)

        self.assertFalse(validator.log.info.called)
        self.assertFalse(validator.log.warning.called)
        self.assertFalse(validator.log.error.called)

    @mock.patch("check_init_priorities.Validator.__init__", return_value=None)
    def test_check_dep_no_prio(self, mock_vinit):
        """If either end of a dependency has no init priority, nothing is logged."""
        validator = check_init_priorities.Validator("", "", None, None)
        validator.log = mock.Mock()
        validator._obj = mock.Mock()

        validator._ord2node = {1: mock.Mock(), 2: mock.Mock()}
        validator._ord2node[1]._binding = None
        validator._ord2node[1].props = {}
        validator._ord2node[2]._binding = None
        validator._ord2node[2].props = {}

        # Only the dependent device has a priority.
        validator._obj.devices = {1: (10, "i1")}
        validator._check_dep(1, 2)

        # Only the dependency has a priority.
        validator._obj.devices = {2: (20, "i2")}
        validator._check_dep(1, 2)

        self.assertFalse(validator.log.info.called)
        self.assertFalse(validator.log.warning.called)
        self.assertFalse(validator.log.error.called)

    @mock.patch("check_init_priorities.Validator.__init__", return_value=None)
    def test_check(self, mock_vinit):
        """Correct ordering is logged at info; inverted ordering is an error."""
        validator = check_init_priorities.Validator("", "", None, None)
        validator.log = mock.Mock()
        validator._obj = mock.Mock()
        validator.errors = 0

        validator._ord2node = {1: mock.Mock(), 2: mock.Mock()}
        validator._ord2node[1]._binding = None
        validator._ord2node[1].path = "/1"
        validator._ord2node[1].props = {}
        validator._ord2node[2]._binding = None
        validator._ord2node[2].path = "/2"
        validator._ord2node[2].props = {}

        validator._obj.devices = {1: (10, "i1"), 2: (20, "i2")}

        validator._check_dep(2, 1)
        validator._check_dep(1, 2)

        validator.log.info.assert_called_once_with("/2 <i2> 20 > /1 <i1> 10")
        validator.log.error.assert_has_calls(
            [mock.call("/1 <i1> is initialized before its dependency /2 <i2> (10 < 20)")]
        )
        self.assertEqual(validator.errors, 1)

    @mock.patch("check_init_priorities.Validator.__init__", return_value=None)
    def test_check_same_prio_assert(self, mock_vinit):
        """Two dependent devices with identical priorities raise ValueError."""
        validator = check_init_priorities.Validator("", "", None, None)
        validator.log = mock.Mock()
        validator._obj = mock.Mock()
        validator.errors = 0

        validator._ord2node = {1: mock.Mock(), 2: mock.Mock()}
        validator._ord2node[1]._binding = None
        validator._ord2node[1].path = "/1"
        validator._ord2node[1].props = {}
        validator._ord2node[2]._binding = None
        validator._ord2node[2].path = "/2"
        validator._ord2node[2].props = {}

        validator._obj.devices = {1: (10, "i1"), 2: (10, "i2")}

        with self.assertRaises(ValueError):
            validator._check_dep(1, 2)

    @mock.patch("check_init_priorities.Validator.__init__", return_value=None)
    def test_check_ignored(self, mock_vinit):
        """Dependencies involving an ignored compatible are skipped."""
        validator = check_init_priorities.Validator("", "", None, None)
        validator.log = mock.Mock()
        validator._obj = mock.Mock()
        validator.errors = 0

        validator._ord2node = {1: mock.Mock(), 3: mock.Mock()}
        validator._ord2node[1]._binding.compatible = "compat-1"
        validator._ord2node[1].path = "/1"
        validator._ord2node[3]._binding.compatible = "compat-3"
        validator._ord2node[3].path = "/3"

        validator._obj.devices = {1: 20, 3: 10}

        # patch.object restores the module-level set even if _check_dep raises,
        # so the override can never leak into other tests.
        with mock.patch.object(check_init_priorities, "_IGNORE_COMPATIBLES", {"compat-3"}):
            validator._check_dep(3, 1)

        self.assertListEqual(
            validator.log.info.call_args_list,
            [
                mock.call("Ignoring priority: compat-3"),
            ],
        )
        self.assertEqual(validator.errors, 0)

    @mock.patch("check_init_priorities.Validator.__init__", return_value=None)
    def test_check_deferred(self, mock_vinit):
        """Deferred deps are ignored, but a non-deferred dependent is an error."""
        validator = check_init_priorities.Validator("", "", None, None)
        validator.log = mock.Mock()
        validator._obj = mock.Mock()
        validator.errors = 0

        validator._ord2node = {1: mock.Mock(), 2: mock.Mock()}
        validator._ord2node[1]._binding = None
        validator._ord2node[1].path = "/1"
        validator._ord2node[1].props = {}
        validator._ord2node[2]._binding = None
        validator._ord2node[2].path = "/2"
        validator._ord2node[2].props = {
            check_init_priorities._DEFERRED_INIT_PROP_NAME: mock.Mock(val=True)
        }

        validator._obj.devices = {1: (10, "i1"), 2: (5, "i2")}

        validator._check_dep(2, 1)
        validator._check_dep(1, 2)

        validator.log.info.assert_called_once_with("Ignoring deferred device /2")
        validator.log.error.assert_has_calls(
            [mock.call("Non-deferred device /1 depends on deferred device /2")]
        )
        self.assertEqual(validator.errors, 1)

    @mock.patch("check_init_priorities.Validator._check_dep")
    @mock.patch("check_init_priorities.Validator.__init__", return_value=None)
    def test_check_edt(self, mock_vinit, mock_cd):
        """check_edt calls _check_dep once per (device, dependency) ordinal pair."""
        d0 = mock.Mock()
        d0.dep_ordinal = 1
        d1 = mock.Mock()
        d1.dep_ordinal = 2
        d2 = mock.Mock()
        d2.dep_ordinal = 3

        dev0 = mock.Mock()
        dev0.depends_on = [d0]
        dev1 = mock.Mock()
        dev1.depends_on = [d1]
        dev2 = mock.Mock()
        dev2.depends_on = [d2]

        validator = check_init_priorities.Validator("", "", None, None)
        validator._ord2node = {1: dev0, 2: dev1, 3: dev2}
        validator._obj = mock.Mock()
        validator._obj.devices = {1: 10, 2: 10, 3: 20}

        validator.check_edt()

        self.assertListEqual(
            mock_cd.call_args_list,
            [
                mock.call(1, 1),
                mock.call(2, 2),
                mock.call(3, 3),
            ],
        )
448
449
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()
452