1 /*
2  *  Copyright (c) 2020, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the Multicast Listeners Table.
32  */
33 
34 #include "multicast_listeners_table.hpp"
35 
36 #if OPENTHREAD_FTD && OPENTHREAD_CONFIG_BACKBONE_ROUTER_MULTICAST_ROUTING_ENABLE
37 
38 #include "instance/instance.hpp"
39 
40 namespace ot {
41 
42 namespace BackboneRouter {
43 
44 RegisterLogModule("BbrMlt");
45 
Add(const Ip6::Address & aAddress,Time aExpireTime)46 Error MulticastListenersTable::Add(const Ip6::Address &aAddress, Time aExpireTime)
47 {
48     Error error = kErrorNone;
49 
50     VerifyOrExit(aAddress.IsMulticastLargerThanRealmLocal(), error = kErrorInvalidArgs);
51 
52     for (uint16_t i = 0; i < mNumValidListeners; i++)
53     {
54         Listener &listener = mListeners[i];
55 
56         if (listener.GetAddress() == aAddress)
57         {
58             listener.SetExpireTime(aExpireTime);
59             FixHeap(i);
60             ExitNow();
61         }
62     }
63 
64     VerifyOrExit(mNumValidListeners < GetArrayLength(mListeners), error = kErrorNoBufs);
65 
66     mListeners[mNumValidListeners].SetAddress(aAddress);
67     mListeners[mNumValidListeners].SetExpireTime(aExpireTime);
68     mNumValidListeners++;
69 
70     FixHeap(mNumValidListeners - 1);
71 
72     mCallback.InvokeIfSet(MapEnum(Listener::kEventAdded), &aAddress);
73 
74 exit:
75     Log(kAdd, aAddress, aExpireTime, error);
76     CheckInvariants();
77     return error;
78 }
79 
Remove(const Ip6::Address & aAddress)80 void MulticastListenersTable::Remove(const Ip6::Address &aAddress)
81 {
82     Error error = kErrorNotFound;
83 
84     for (uint16_t i = 0; i < mNumValidListeners; i++)
85     {
86         Listener &listener = mListeners[i];
87 
88         if (listener.GetAddress() == aAddress)
89         {
90             mNumValidListeners--;
91 
92             if (i != mNumValidListeners)
93             {
94                 listener = mListeners[mNumValidListeners];
95                 FixHeap(i);
96             }
97 
98             mCallback.InvokeIfSet(MapEnum(Listener::kEventRemoved), &aAddress);
99 
100             ExitNow(error = kErrorNone);
101         }
102     }
103 
104 exit:
105     Log(kRemove, aAddress, TimeMilli(0), error);
106     CheckInvariants();
107 }
108 
Expire(void)109 void MulticastListenersTable::Expire(void)
110 {
111     TimeMilli    now = TimerMilli::GetNow();
112     Ip6::Address address;
113 
114     while (mNumValidListeners > 0 && now >= mListeners[0].GetExpireTime())
115     {
116         Log(kExpire, mListeners[0].GetAddress(), mListeners[0].GetExpireTime(), kErrorNone);
117         address = mListeners[0].GetAddress();
118 
119         mNumValidListeners--;
120 
121         if (mNumValidListeners > 0)
122         {
123             mListeners[0] = mListeners[mNumValidListeners];
124             FixHeap(0);
125         }
126 
127         mCallback.InvokeIfSet(MapEnum(Listener::kEventRemoved), &address);
128     }
129 
130     CheckInvariants();
131 }
132 
133 #if OT_SHOULD_LOG_AT(OT_LOG_LEVEL_DEBG)
Log(Action aAction,const Ip6::Address & aAddress,TimeMilli aExpireTime,Error aError) const134 void MulticastListenersTable::Log(Action              aAction,
135                                   const Ip6::Address &aAddress,
136                                   TimeMilli           aExpireTime,
137                                   Error               aError) const
138 {
139     static const char *const kActionStrings[] = {
140         "Add",    // (0) kAdd
141         "Remove", // (1) kRemove
142         "Expire", // (2) kExpire
143     };
144 
145     struct EnumCheck
146     {
147         InitEnumValidatorCounter();
148         ValidateNextEnum(kAdd);
149         ValidateNextEnum(kRemove);
150         ValidateNextEnum(kExpire);
151     };
152 
153     LogDebg("%s %s expire %lu: %s", kActionStrings[aAction], aAddress.ToString().AsCString(),
154             ToUlong(aExpireTime.GetValue()), ErrorToString(aError));
155 }
156 #else
Log(Action,const Ip6::Address &,TimeMilli,Error) const157 void MulticastListenersTable::Log(Action, const Ip6::Address &, TimeMilli, Error) const {}
158 #endif
159 
FixHeap(uint16_t aIndex)160 void MulticastListenersTable::FixHeap(uint16_t aIndex)
161 {
162     if (!SiftHeapElemDown(aIndex))
163     {
164         SiftHeapElemUp(aIndex);
165     }
166 }
167 
CheckInvariants(void) const168 void MulticastListenersTable::CheckInvariants(void) const
169 {
170 #if OPENTHREAD_EXAMPLES_SIMULATION && OPENTHREAD_CONFIG_ASSERT_ENABLE
171     for (uint16_t child = 1; child < mNumValidListeners; ++child)
172     {
173         uint16_t parent = (child - 1) / 2;
174 
175         OT_ASSERT(!(mListeners[child] < mListeners[parent]));
176     }
177 #endif
178 }
179 
SiftHeapElemDown(uint16_t aIndex)180 bool MulticastListenersTable::SiftHeapElemDown(uint16_t aIndex)
181 {
182     uint16_t index = aIndex;
183     Listener saveElem;
184 
185     OT_ASSERT(aIndex < mNumValidListeners);
186 
187     saveElem = mListeners[aIndex];
188 
189     for (;;)
190     {
191         uint16_t child = 2 * index + 1;
192 
193         if (child >= mNumValidListeners || child <= index) // child <= index after int overflow
194         {
195             break;
196         }
197 
198         if (child + 1 < mNumValidListeners && mListeners[child + 1] < mListeners[child])
199         {
200             child++;
201         }
202 
203         if (!(mListeners[child] < saveElem))
204         {
205             break;
206         }
207 
208         mListeners[index] = mListeners[child];
209 
210         index = child;
211     }
212 
213     if (index > aIndex)
214     {
215         mListeners[index] = saveElem;
216     }
217 
218     return index > aIndex;
219 }
220 
SiftHeapElemUp(uint16_t aIndex)221 void MulticastListenersTable::SiftHeapElemUp(uint16_t aIndex)
222 {
223     uint16_t index = aIndex;
224     Listener saveElem;
225 
226     OT_ASSERT(aIndex < mNumValidListeners);
227 
228     saveElem = mListeners[aIndex];
229 
230     for (;;)
231     {
232         uint16_t parent = (index - 1) / 2;
233 
234         if (index == 0 || !(saveElem < mListeners[parent]))
235         {
236             break;
237         }
238 
239         mListeners[index] = mListeners[parent];
240 
241         index = parent;
242     }
243 
244     if (index < aIndex)
245     {
246         mListeners[index] = saveElem;
247     }
248 }
249 
begin(void)250 MulticastListenersTable::Listener *MulticastListenersTable::IteratorBuilder::begin(void)
251 {
252     return &Get<MulticastListenersTable>().mListeners[0];
253 }
254 
end(void)255 MulticastListenersTable::Listener *MulticastListenersTable::IteratorBuilder::end(void)
256 {
257     return &Get<MulticastListenersTable>().mListeners[Get<MulticastListenersTable>().mNumValidListeners];
258 }
259 
Clear(void)260 void MulticastListenersTable::Clear(void)
261 {
262     if (mCallback.IsSet())
263     {
264         for (uint16_t i = 0; i < mNumValidListeners; i++)
265         {
266             mCallback.Invoke(MapEnum(Listener::kEventRemoved), &mListeners[i].GetAddress());
267         }
268     }
269 
270     mNumValidListeners = 0;
271 
272     CheckInvariants();
273 }
274 
SetCallback(Listener::Callback aCallback,void * aContext)275 void MulticastListenersTable::SetCallback(Listener::Callback aCallback, void *aContext)
276 {
277     mCallback.Set(aCallback, aContext);
278 
279     if (mCallback.IsSet())
280     {
281         for (uint16_t i = 0; i < mNumValidListeners; i++)
282         {
283             mCallback.Invoke(MapEnum(Listener::kEventAdded), &mListeners[i].GetAddress());
284         }
285     }
286 }
287 
GetNext(Listener::Iterator & aIterator,Listener::Info & aInfo)288 Error MulticastListenersTable::GetNext(Listener::Iterator &aIterator, Listener::Info &aInfo)
289 {
290     Error     error = kErrorNone;
291     TimeMilli now;
292 
293     VerifyOrExit(aIterator < mNumValidListeners, error = kErrorNotFound);
294 
295     now = TimerMilli::GetNow();
296 
297     aInfo.mAddress = mListeners[aIterator].mAddress;
298     aInfo.mTimeout =
299         Time::MsecToSec(mListeners[aIterator].mExpireTime > now ? mListeners[aIterator].mExpireTime - now : 0);
300 
301     aIterator++;
302 
303 exit:
304     return error;
305 }
306 
307 } // namespace BackboneRouter
308 
309 } // namespace ot
310 
311 #endif // OPENTHREAD_FTD && OPENTHREAD_CONFIG_BACKBONE_ROUTER_MULTICAST_ROUTING_ENABLE
312