1#!/usr/bin/python
2# SPDX-License-Identifier: GPL-2.0
3
4import subprocess
5import json as j
6import random
7
8
class SkipTest(Exception):
    """Raised by a value picker for an object that must be left unchanged."""
11
12
class RandomValuePicker:
    """
    Class for storing shared buffer configuration. Can handle 3 different
    objects, pool, tcbind and portpool. Provide an interface to get random
    values for a specific object type as follows:
      1. Pool:
         - random size
         - static threshold type

      2. TcBind:
         - random pool number
         - random threshold

      3. PortPool:
         - random threshold
    """
    # The threshold type of these pools cannot be changed.
    _FIXED_THTYPE_POOLS = (4, 8, 9, 10)
    # TCs numbered from here on are multicast TCs, which cannot be changed.
    _FIRST_MC_TC = 8

    def __init__(self, pools):
        self._pools = list(pools)

    def _cell_size(self):
        # Taken from the first pool; presumably identical for all pools of
        # the device -- TODO confirm.
        return self._pools[0]["cell_size"]

    def _get_static_size(self, th):
        # For threshold of 16, this works out to be about 12MB on Spectrum-1,
        # and about 17MB on Spectrum-2.
        return th * 8000 * self._cell_size()

    def _get_size(self):
        return self._get_static_size(16)

    def _get_thtype(self):
        return "static"

    def _get_th(self, pool):
        # Threshold value could be any integer between 3 and 16 (inclusive).
        th = random.randint(3, 16)
        if pool["thtype"] == "dynamic":
            return th
        # Static pools take their threshold as a size rather than an index.
        return self._get_static_size(th)

    def _get_pool(self, direction):
        """Return a random pool of *direction*.

        Any direction other than "ingress" selects among the egress pools.
        Raises ValueError if there is no pool of that direction.
        """
        want_ingress = direction == "ingress"
        candidates = [pool for pool in self._pools
                      if (pool["type"] == "ingress") == want_ingress]
        if not candidates:
            raise ValueError("No {} pool to pick from".format(direction))
        return candidates[random.randint(0, len(candidates) - 1)]

    def _find_pool(self, pool_n, sb=None):
        """Return the stored pool numbered *pool_n*.

        When *sb* is given, the pool must also belong to that shared buffer.
        Raises KeyError if no stored pool matches.
        """
        for pool in self._pools:
            if pool["pool"] == pool_n and (sb is None or pool.get("sb") == sb):
                return pool
        raise KeyError(pool_n)

    def get_value(self, objid):
        """Return a tuple of random values to apply to *objid*.

        Pool -> (size, thtype); TcBind -> (pool number, threshold);
        PortPool -> (threshold,).

        Raises SkipTest for objects that cannot be changed and TypeError for
        unsupported object types.
        """
        if isinstance(objid, Pool):
            if objid["pool"] in self._FIXED_THTYPE_POOLS:
                raise SkipTest()
            return (self._get_size(), self._get_thtype())
        if isinstance(objid, TcBind):
            if objid["tc"] >= self._FIRST_MC_TC:
                raise SkipTest()
            pool = self._get_pool(objid["type"])
            th = self._get_th(pool)
            return (pool["pool"], th)
        if isinstance(objid, PortPool):
            # Look the pool up by number instead of assuming that the pool
            # list is indexed by pool number.
            pool = self._find_pool(objid["pool"], objid.get("sb"))
            return (self._get_th(pool),)
        raise TypeError("Unsupported object type: {}".format(type(objid).__name__))
91
92
class RecordValuePickerException(Exception):
    """Raised when RecordValuePicker holds no recorded value for an object."""
95
96
class RecordValuePicker:
    """
    Class for storing shared buffer configuration. Records the current values
    of a list of objects (pools, tc binds or port pools) and provides an
    interface to get the recorded values back for a matching object, e.g. to
    restore a configuration after it was changed.
    """
    def __init__(self, objlist):
        # Snapshot the variable fields of every object now; objects are
        # matched against these records later with weak_eq().
        self._recs = [{"objid": item, "value": item.var_tuple()}
                      for item in objlist]

    def get_value(self, objid):
        """Return the recorded value tuple of the object matching *objid*.

        Raises SkipTest for pools 4, 8, 9 and 10 and for TCs >= 8 (the same
        objects RandomValuePicker skips), and RecordValuePickerException
        when no recorded object matches.
        """
        if isinstance(objid, Pool) and objid["pool"] in [4, 8, 9, 10]:
            # The threshold type of pools 4, 8, 9 and 10 cannot be changed
            raise SkipTest()
        if isinstance(objid, TcBind) and objid["tc"] >= 8:
            # Multicast TCs cannot be changed
            raise SkipTest()
        for rec in self._recs:
            if rec["objid"].weak_eq(objid):
                return rec["value"]
        raise RecordValuePickerException(
            "No recorded value for {}".format(objid))
119
120
def run_cmd(cmd, json=False):
    """Run *cmd* through the shell and return its standard output.

    The output is returned as produced by subprocess.check_output, or parsed
    with json.loads when *json* is true. A non-zero exit status raises
    subprocess.CalledProcessError.
    """
    raw = subprocess.check_output(cmd, shell=True)
    return j.loads(raw) if json else raw
126
127
def run_json_cmd(cmd):
    """Run *cmd* through the shell and return its output parsed as JSON."""
    parsed = run_cmd(cmd, json=True)
    return parsed
130
131
def log_test(test_name, err_msg=None):
    """Print a one-line result for *test_name*.

    A truthy *err_msg* is printed first, tab-indented, and the test is
    reported as [FAIL]; otherwise it is reported as [ OK ].
    """
    if not err_msg:
        print("TEST: %-80s  [ OK ]" % test_name)
        return
    print("\t%s" % err_msg)
    print("TEST: %-80s  [FAIL]" % test_name)
138
139
class CommonItem(dict):
    """
    Dictionary describing one devlink shared buffer object.

    Subclasses list in ``varitems`` the keys that hold settable values; all
    other keys identify the object.
    """
    varitems = []

    def var_tuple(self):
        """Return the values of the ``varitems`` keys, ordered by key name."""
        # sorted() rather than list.sort(): sorting in place mutated the
        # class-level varitems list shared by every instance.
        return tuple(self[key] for key in sorted(self.varitems))

    def weak_eq(self, other):
        """Return True if *other* matches self on every non-variable key.

        Keys listed in ``varitems`` are ignored. A key of self that *other*
        lacks counts as a mismatch.
        """
        for key in self:
            if key in self.varitems:
                continue
            if key not in other or self[key] != other[key]:
                return False
        return True
157
158
class CommonList(list):
    """List of CommonItem objects with lookup and removal by weak_eq()."""

    def get_by(self, by_obj):
        """Return the first item weakly equal to *by_obj*, or None."""
        return next((item for item in self if item.weak_eq(by_obj)), None)

    def del_by(self, by_obj):
        """Remove, in place, every item weakly equal to *by_obj*."""
        # Rebuild the contents instead of calling remove() inside the loop:
        # removing from a list while iterating it skips the next element.
        self[:] = [item for item in self if not item.weak_eq(by_obj)]
170
171
class Pool(CommonItem):
    """Shared buffer pool; ``size`` and ``thtype`` are its settable fields."""
    varitems = ["size", "thtype"]

    def dl_set(self, dlname, size, thtype):
        """Apply *size* and *thtype* to this pool of device *dlname*."""
        cmd = "devlink sb pool set {} sb {} pool {} size {} thtype {}".format(
            dlname, self["sb"], self["pool"], size, thtype)
        run_cmd(cmd)
179
180
class PoolList(CommonList):
    """CommonList holding Pool objects."""
183
184
def get_pools(dlname, direction=None):
    """Return a PoolList of the pools of devlink device *dlname*.

    When *direction* is given, only pools whose "type" equals it are kept.
    """
    shown = run_json_cmd("devlink sb pool show -j")
    return PoolList(Pool(entry) for entry in shown["pool"][dlname]
                    if not direction or direction == entry["type"])
192
193
def do_check_pools(dlname, pools, vp):
    """Set each pool to the values chosen by *vp* and verify the result.

    For every pool in *pools*, the value picker *vp* supplies a
    (size, thtype) pair that is applied with devlink. The pool configuration
    is then re-read and a test result is logged: the pool must carry the
    requested values and no other pool may have changed. Pools for which
    *vp* raises SkipTest are left untouched.
    """
    for pool in pools:
        pre_pools = get_pools(dlname)
        try:
            (size, thtype) = vp.get_value(pool)
        except SkipTest:
            continue
        pool.dl_set(dlname, size, thtype)
        post_pools = get_pools(dlname)
        test_name = "pool {} of sb {} set verification".format(pool["pool"],
                                                               pool["sb"])

        # weak_eq() ignores size/thtype, so this finds the same pool in the
        # freshly read list.
        new_pool = post_pools.get_by(pool)
        if new_pool is None:
            log_test(test_name, "Pool disappeared after set")
            continue

        # Collect every mismatch instead of keeping only the last one.
        errors = []
        if new_pool["size"] != size:
            errors.append("Incorrect pool size (got {}, expected {})".format(
                new_pool["size"], size))
        if new_pool["thtype"] != thtype:
            errors.append("Incorrect pool threshold type (got {}, expected {})".format(
                new_pool["thtype"], thtype))

        pre_pools.del_by(new_pool)
        post_pools.del_by(new_pool)
        if pre_pools != post_pools:
            errors.append("Other pool setup changed as well")
        log_test(test_name, "; ".join(errors) or None)
217
218
def check_pools(dlname, pools):
    """Randomize every changeable pool, verify, then restore the originals."""
    # Snapshot the current values before anything is modified.
    saved = RecordValuePicker(pools)

    # For each pool, set random size and static threshold type
    randomizer = RandomValuePicker(pools)
    do_check_pools(dlname, pools, randomizer)

    # Put the recorded defaults back.
    do_check_pools(dlname, pools, saved)
228
229
class TcBind(CommonItem):
    """Binding of a port TC to a pool; ``pool`` and ``threshold`` are settable."""
    varitems = ["pool", "threshold"]

    def __init__(self, port, d):
        super(TcBind, self).__init__(d)
        # Remember the owning port; it is part of the identity and of dl_set().
        self["dlportname"] = port.name

    def dl_set(self, pool, th):
        """Bind this TC to *pool* with threshold *th* via devlink."""
        args = (self["dlportname"], self["sb"], self["tc"], self["type"], pool, th)
        run_cmd("devlink sb tc bind set {} sb {} tc {} type {} pool {} th {}".format(*args))
243
244
class TcBindList(CommonList):
    """CommonList holding TcBind objects."""
247
248
def get_tcbinds(ports, verify_existence=False):
    """Return a TcBindList with the tc binds of every port in *ports*.

    When *verify_existence* is true, also log a test result per port that
    fails if devlink reported no tc binds for that port.
    """
    shown = run_json_cmd("devlink sb tc bind show -j -n")
    tcbinds = TcBindList()
    for port in ports:
        by_port = shown["tc_bind"]
        found = port.name in by_port and len(by_port[port.name]) > 0
        if found:
            tcbinds.extend(TcBind(port, entry) for entry in by_port[port.name])
        if verify_existence:
            log_test("tc bind existence for port {} verification".format(port.name),
                     None if found else "No tc bind for port")
    return tcbinds
262
263
def do_check_tcbind(ports, tcbinds, vp):
    """Rebind each tc bind to the values chosen by *vp* and verify the result.

    For every entry of *tcbinds*, the value picker *vp* supplies a
    (pool, threshold) pair that is applied with devlink. The tc binds of
    *ports* are then re-read and a test result is logged: the entry must
    carry the requested values and no other tc bind may have changed.
    Entries for which *vp* raises SkipTest are left untouched.
    """
    for tcbind in tcbinds:
        pre_tcbinds = get_tcbinds(ports)
        try:
            (pool, th) = vp.get_value(tcbind)
        except SkipTest:
            continue
        tcbind.dl_set(pool, th)
        post_tcbinds = get_tcbinds(ports)
        test_name = "tc bind {}-{} of sb {} set verification".format(tcbind["dlportname"],
                                                                     tcbind["tc"],
                                                                     tcbind["sb"])

        # weak_eq() ignores pool/threshold, so this finds the same entry in
        # the freshly read list.
        new_tcbind = post_tcbinds.get_by(tcbind)
        if new_tcbind is None:
            log_test(test_name, "Tc bind disappeared after set")
            continue

        # Collect every mismatch instead of keeping only the last one.
        errors = []
        if new_tcbind["pool"] != pool:
            errors.append("Incorrect pool (got {}, expected {})".format(
                new_tcbind["pool"], pool))
        if new_tcbind["threshold"] != th:
            errors.append("Incorrect threshold (got {}, expected {})".format(
                new_tcbind["threshold"], th))

        pre_tcbinds.del_by(new_tcbind)
        post_tcbinds.del_by(new_tcbind)
        if pre_tcbinds != post_tcbinds:
            errors.append("Other tc bind setup changed as well")
        log_test(test_name, "; ".join(errors) or None)
288
289
def check_tcbind(dlname, ports, pools):
    """Randomize every unicast tc bind of *ports*, verify, then restore.

    *dlname* is accepted for symmetry with the other checks but not used.
    """
    tcbinds = get_tcbinds(ports, verify_existence=True)
    # Snapshot the current bindings before anything is modified.
    saved = RecordValuePicker(tcbinds)

    # Bind each port and unicast TC (TCs < 8) to a random pool and a random
    # threshold
    randomizer = RandomValuePicker(pools)
    do_check_tcbind(ports, tcbinds, randomizer)

    # Put the recorded defaults back.
    do_check_tcbind(ports, tcbinds, saved)
302
303
class PortPool(CommonItem):
    """Per-port view of a pool; only ``threshold`` is settable."""
    varitems = ["threshold"]

    def __init__(self, port, d):
        super(PortPool, self).__init__(d)
        # Remember the owning port; it is part of the identity and of dl_set().
        self["dlportname"] = port.name

    def dl_set(self, th):
        """Set this port pool's threshold to *th* via devlink."""
        template = "devlink sb port pool set {} sb {} pool {} th {}"
        run_cmd(template.format(self["dlportname"], self["sb"], self["pool"], th))
315
316
class PortPoolList(CommonList):
    """CommonList holding PortPool objects."""
319
320
def get_portpools(ports, verify_existence=False):
    """Return a PortPoolList with the port pools of every port in *ports*.

    When *verify_existence* is true, also log a test result per port that
    fails if devlink reported no port pools for that port.
    """
    shown = run_json_cmd("devlink sb port pool -j -n")
    portpools = PortPoolList()
    for port in ports:
        by_port = shown["port_pool"]
        found = port.name in by_port and len(by_port[port.name]) > 0
        if found:
            portpools.extend(PortPool(port, entry) for entry in by_port[port.name])
        if verify_existence:
            log_test("port pool existence for port {} verification".format(port.name),
                     None if found else "No port pool for port")
    return portpools
334
335
def do_check_portpool(ports, portpools, vp):
    """Set each port pool's threshold chosen by *vp* and verify the result.

    For every entry of *portpools*, the value picker *vp* supplies a
    (threshold,) tuple that is applied with devlink. The port pools of
    *ports* are then re-read and a test result is logged: the entry must
    carry the requested threshold and no other port pool may have changed.
    Entries for which *vp* raises SkipTest are left untouched, consistent
    with the pool and tc bind checks.
    """
    for portpool in portpools:
        pre_portpools = get_portpools(ports)
        try:
            (th,) = vp.get_value(portpool)
        except SkipTest:
            continue
        portpool.dl_set(th)
        post_portpools = get_portpools(ports)
        test_name = "port pool {}-{} of sb {} set verification".format(portpool["dlportname"],
                                                                       portpool["pool"],
                                                                       portpool["sb"])

        # weak_eq() ignores the threshold, so this finds the same entry in
        # the freshly read list.
        new_portpool = post_portpools.get_by(portpool)
        if new_portpool is None:
            log_test(test_name, "Port pool disappeared after set")
            continue

        # Collect every mismatch instead of keeping only the last one.
        errors = []
        if new_portpool["threshold"] != th:
            errors.append("Incorrect threshold (got {}, expected {})".format(
                new_portpool["threshold"], th))

        pre_portpools.del_by(new_portpool)
        post_portpools.del_by(new_portpool)
        if pre_portpools != post_portpools:
            errors.append("Other port pool setup changed as well")
        log_test(test_name, "; ".join(errors) or None)
355
356
def check_portpool(dlname, ports, pools):
    """Randomize every port pool threshold of *ports*, verify, then restore.

    *dlname* is accepted for symmetry with the other checks but not used.
    """
    portpools = get_portpools(ports, verify_existence=True)
    # Snapshot the current thresholds before anything is modified.
    saved = RecordValuePicker(portpools)

    # For each port pool, set a random threshold
    randomizer = RandomValuePicker(pools)
    do_check_portpool(ports, portpools, randomizer)

    # Put the recorded defaults back.
    do_check_portpool(ports, portpools, saved)
368
369
class Port:
    """A devlink port, known only by its *name* (the devlink port handle)."""

    def __init__(self, name):
        self.name = name
373
374
class PortList(list):
    """Plain list of Port objects."""
377
378
379def get_ports(dlname):
380    d = run_json_cmd("devlink port show -j")
381    ports = PortList()
382    for name in d["port"]:
383        if name.find(dlname) == 0 and d["port"][name]["flavour"] == "physical":
384            ports.append(Port(name))
385    return ports
386
387
def get_device():
    """Return the handle of the first mlxsw_spectrum devlink device, or None.

    A device matches when its reported driver name contains "mlxsw_spectrum".
    """
    info = run_json_cmd("devlink -j dev info")["info"]
    matches = (dev for dev in info if "mlxsw_spectrum" in info[dev]["driver"])
    return next(matches, None)
394
395
class UnavailableDevlinkNameException(Exception):
    """Raised when no mlxsw_spectrum devlink device can be found."""
398
399
def test_sb_configuration():
    """Run the pool, tc bind and port pool checks on the mlxsw device.

    Raises UnavailableDevlinkNameException if no such device is found.
    """
    # A fixed seed keeps the "random" values identical between runs.
    random.seed(0)

    dlname = get_device()
    if not dlname:
        raise UnavailableDevlinkNameException()

    ports, pools = get_ports(dlname), get_pools(dlname)

    check_pools(dlname, pools)
    check_tcbind(dlname, ports, pools)
    check_portpool(dlname, ports, pools)
414
415
# Run the checks only when executed as a script, not when imported.
if __name__ == "__main__":
    test_sb_configuration()
417