1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the message buffer pool and message buffers.
32  */
33 
34 #include "message.hpp"
35 
36 #include "common/as_core_type.hpp"
37 #include "common/code_utils.hpp"
38 #include "common/debug.hpp"
39 #include "common/heap.hpp"
40 #include "common/locator_getters.hpp"
41 #include "common/log.hpp"
42 #include "common/num_utils.hpp"
43 #include "common/numeric_limits.hpp"
44 #include "instance/instance.hpp"
45 #include "net/checksum.hpp"
46 #include "net/ip6.hpp"
47 
48 #if OPENTHREAD_MTD || OPENTHREAD_FTD
49 
50 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE && OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
51 #error "OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE conflicts with OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT."
52 #endif
53 
54 namespace ot {
55 
// Log module registration for this file; log output emitted here (e.g. the
// `LogInfo()` call in `MessagePool::NewBuffer()`) is presumably tagged "Message".
RegisterLogModule("Message");
57 
58 //---------------------------------------------------------------------------------------------------------------------
59 // MessagePool
60 
MessagePool(Instance & aInstance)61 MessagePool::MessagePool(Instance &aInstance)
62     : InstanceLocator(aInstance)
63     , mNumAllocated(0)
64     , mMaxAllocated(0)
65 {
66 #if OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
67     otPlatMessagePoolInit(&GetInstance(), kNumBuffers, sizeof(Buffer));
68 #endif
69 }
70 
Allocate(Message::Type aType,uint16_t aReserveHeader,const Message::Settings & aSettings)71 Message *MessagePool::Allocate(Message::Type aType, uint16_t aReserveHeader, const Message::Settings &aSettings)
72 {
73     Error    error = kErrorNone;
74     Message *message;
75 
76     VerifyOrExit((message = static_cast<Message *>(NewBuffer(aSettings.GetPriority()))) != nullptr);
77 
78     ClearAllBytes(*message);
79     message->SetMessagePool(this);
80     message->SetType(aType);
81     message->SetReserved(aReserveHeader);
82     message->SetLinkSecurityEnabled(aSettings.IsLinkSecurityEnabled());
83     message->SetLoopbackToHostAllowed(OPENTHREAD_CONFIG_IP6_ALLOW_LOOP_BACK_HOST_DATAGRAMS);
84     message->SetOrigin(Message::kOriginHostTrusted);
85 
86     SuccessOrExit(error = message->SetPriority(aSettings.GetPriority()));
87     SuccessOrExit(error = message->SetLength(0));
88 
89 exit:
90     if (error != kErrorNone)
91     {
92         Free(message);
93         message = nullptr;
94     }
95 
96     return message;
97 }
98 
Allocate(Message::Type aType)99 Message *MessagePool::Allocate(Message::Type aType) { return Allocate(aType, 0, Message::Settings::GetDefault()); }
100 
Allocate(Message::Type aType,uint16_t aReserveHeader)101 Message *MessagePool::Allocate(Message::Type aType, uint16_t aReserveHeader)
102 {
103     return Allocate(aType, aReserveHeader, Message::Settings::GetDefault());
104 }
105 
Free(Message * aMessage)106 void MessagePool::Free(Message *aMessage)
107 {
108     OT_ASSERT(aMessage->Next() == nullptr && aMessage->Prev() == nullptr);
109 
110     FreeBuffers(static_cast<Buffer *>(aMessage));
111 }
112 
NewBuffer(Message::Priority aPriority)113 Buffer *MessagePool::NewBuffer(Message::Priority aPriority)
114 {
115     Buffer *buffer = nullptr;
116 
117     while ((
118 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
119                buffer = static_cast<Buffer *>(Heap::CAlloc(1, sizeof(Buffer)))
120 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
121                buffer = static_cast<Buffer *>(otPlatMessagePoolNew(&GetInstance()))
122 #else
123                buffer = mBufferPool.Allocate()
124 #endif
125                    ) == nullptr)
126     {
127         SuccessOrExit(ReclaimBuffers(aPriority));
128     }
129 
130     mNumAllocated++;
131     mMaxAllocated = Max(mMaxAllocated, mNumAllocated);
132 
133     buffer->SetNextBuffer(nullptr);
134 
135 exit:
136     if (buffer == nullptr)
137     {
138         LogInfo("No available message buffer");
139     }
140 
141     return buffer;
142 }
143 
FreeBuffers(Buffer * aBuffer)144 void MessagePool::FreeBuffers(Buffer *aBuffer)
145 {
146     while (aBuffer != nullptr)
147     {
148         Buffer *next = aBuffer->GetNextBuffer();
149 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
150         Heap::Free(aBuffer);
151 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
152         otPlatMessagePoolFree(&GetInstance(), aBuffer);
153 #else
154         mBufferPool.Free(*aBuffer);
155 #endif
156         mNumAllocated--;
157 
158         aBuffer = next;
159     }
160 }
161 
ReclaimBuffers(Message::Priority aPriority)162 Error MessagePool::ReclaimBuffers(Message::Priority aPriority) { return Get<MeshForwarder>().EvictMessage(aPriority); }
163 
GetFreeBufferCount(void) const164 uint16_t MessagePool::GetFreeBufferCount(void) const
165 {
166     uint16_t rval;
167 
168 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
169 #if !OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE
170     rval = static_cast<uint16_t>(Instance::GetHeap().GetFreeSize() / sizeof(Buffer));
171 #else
172     rval = NumericLimits<uint16_t>::kMax;
173 #endif
174 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
175     rval = otPlatMessagePoolNumFreeBuffers(&GetInstance());
176 #else
177     rval = kNumBuffers - mNumAllocated;
178 #endif
179 
180     return rval;
181 }
182 
GetTotalBufferCount(void) const183 uint16_t MessagePool::GetTotalBufferCount(void) const
184 {
185     uint16_t rval;
186 
187 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
188 #if !OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE
189     rval = static_cast<uint16_t>(Instance::GetHeap().GetCapacity() / sizeof(Buffer));
190 #else
191     rval = NumericLimits<uint16_t>::kMax;
192 #endif
193 #else
194     rval = OPENTHREAD_CONFIG_NUM_MESSAGE_BUFFERS;
195 #endif
196 
197     return rval;
198 }
199 
200 //---------------------------------------------------------------------------------------------------------------------
201 // Message::Settings
202 
// Default message settings: link security enabled, normal priority.
// NOTE(review): presumably what `Settings::GetDefault()` (declared in the header) refers to.
const otMessageSettings Message::Settings::kDefault = {kWithLinkSecurity, kPriorityNormal};
204 
Settings(LinkSecurityMode aSecurityMode,Priority aPriority)205 Message::Settings::Settings(LinkSecurityMode aSecurityMode, Priority aPriority)
206 {
207     mLinkSecurityEnabled = aSecurityMode;
208     mPriority            = aPriority;
209 }
210 
From(const otMessageSettings * aSettings)211 const Message::Settings &Message::Settings::From(const otMessageSettings *aSettings)
212 {
213     return (aSettings == nullptr) ? GetDefault() : AsCoreType(aSettings);
214 }
215 
216 //---------------------------------------------------------------------------------------------------------------------
217 // Message::Iterator
218 
Advance(void)219 void Message::Iterator::Advance(void)
220 {
221     mItem = mNext;
222     mNext = NextMessage(mNext);
223 }
224 
225 //---------------------------------------------------------------------------------------------------------------------
226 // Message
227 
ResizeMessage(uint16_t aLength)228 Error Message::ResizeMessage(uint16_t aLength)
229 {
230     // This method adds or frees message buffers to meet the
231     // requested length.
232 
233     Error    error     = kErrorNone;
234     Buffer  *curBuffer = this;
235     Buffer  *lastBuffer;
236     uint16_t curLength = kHeadBufferDataSize;
237 
238     while (curLength < aLength)
239     {
240         if (curBuffer->GetNextBuffer() == nullptr)
241         {
242             curBuffer->SetNextBuffer(GetMessagePool()->NewBuffer(GetPriority()));
243             VerifyOrExit(curBuffer->GetNextBuffer() != nullptr, error = kErrorNoBufs);
244         }
245 
246         curBuffer = curBuffer->GetNextBuffer();
247         curLength += kBufferDataSize;
248     }
249 
250     lastBuffer = curBuffer;
251     curBuffer  = curBuffer->GetNextBuffer();
252     lastBuffer->SetNextBuffer(nullptr);
253 
254     GetMessagePool()->FreeBuffers(curBuffer);
255 
256 exit:
257     return error;
258 }
259 
Free(void)260 void Message::Free(void) { GetMessagePool()->Free(this); }
261 
GetNext(void) const262 Message *Message::GetNext(void) const
263 {
264     Message *next;
265     Message *tail;
266 
267     if (GetMetadata().mInPriorityQ)
268     {
269         PriorityQueue *priorityQueue = GetPriorityQueue();
270         VerifyOrExit(priorityQueue != nullptr, next = nullptr);
271         tail = priorityQueue->GetTail();
272     }
273     else
274     {
275         MessageQueue *messageQueue = GetMessageQueue();
276         VerifyOrExit(messageQueue != nullptr, next = nullptr);
277         tail = messageQueue->GetTail();
278     }
279 
280     next = (this == tail) ? nullptr : Next();
281 
282 exit:
283     return next;
284 }
285 
SetLength(uint16_t aLength)286 Error Message::SetLength(uint16_t aLength)
287 {
288     Error    error              = kErrorNone;
289     uint16_t totalLengthRequest = GetReserved() + aLength;
290 
291     VerifyOrExit(totalLengthRequest >= GetReserved(), error = kErrorInvalidArgs);
292 
293     SuccessOrExit(error = ResizeMessage(totalLengthRequest));
294     GetMetadata().mLength = aLength;
295 
296     // Correct the offset in case shorter length is set.
297     if (GetOffset() > aLength)
298     {
299         SetOffset(aLength);
300     }
301 
302 exit:
303     return error;
304 }
305 
GetBufferCount(void) const306 uint8_t Message::GetBufferCount(void) const
307 {
308     uint8_t rval = 1;
309 
310     for (const Buffer *curBuffer = GetNextBuffer(); curBuffer; curBuffer = curBuffer->GetNextBuffer())
311     {
312         rval++;
313     }
314 
315     return rval;
316 }
317 
MoveOffset(int aDelta)318 void Message::MoveOffset(int aDelta)
319 {
320     OT_ASSERT(GetOffset() + aDelta <= GetLength());
321     GetMetadata().mOffset += static_cast<int16_t>(aDelta);
322     OT_ASSERT(GetMetadata().mOffset <= GetLength());
323 }
324 
SetOffset(uint16_t aOffset)325 void Message::SetOffset(uint16_t aOffset)
326 {
327     OT_ASSERT(aOffset <= GetLength());
328     GetMetadata().mOffset = aOffset;
329 }
330 
IsSubTypeMle(void) const331 bool Message::IsSubTypeMle(void) const
332 {
333     bool rval;
334 
335     switch (GetMetadata().mSubType)
336     {
337     case kSubTypeMleGeneral:
338     case kSubTypeMleAnnounce:
339     case kSubTypeMleDiscoverRequest:
340     case kSubTypeMleDiscoverResponse:
341     case kSubTypeMleChildUpdateRequest:
342     case kSubTypeMleDataResponse:
343     case kSubTypeMleChildIdRequest:
344         rval = true;
345         break;
346 
347     default:
348         rval = false;
349         break;
350     }
351 
352     return rval;
353 }
354 
SetPriority(Priority aPriority)355 Error Message::SetPriority(Priority aPriority)
356 {
357     Error          error    = kErrorNone;
358     uint8_t        priority = static_cast<uint8_t>(aPriority);
359     PriorityQueue *priorityQueue;
360 
361     static_assert(kNumPriorities <= 4, "`Metadata::mPriority` as a 2-bit field cannot fit all `Priority` values");
362 
363     VerifyOrExit(priority < kNumPriorities, error = kErrorInvalidArgs);
364 
365     VerifyOrExit(IsInAQueue(), GetMetadata().mPriority = priority);
366     VerifyOrExit(GetMetadata().mPriority != priority);
367 
368     priorityQueue = GetPriorityQueue();
369 
370     if (priorityQueue != nullptr)
371     {
372         priorityQueue->Dequeue(*this);
373     }
374 
375     GetMetadata().mPriority = priority;
376 
377     if (priorityQueue != nullptr)
378     {
379         priorityQueue->Enqueue(*this);
380     }
381 
382 exit:
383     return error;
384 }
385 
PriorityToString(Priority aPriority)386 const char *Message::PriorityToString(Priority aPriority)
387 {
388     static const char *const kPriorityStrings[] = {
389         "low",    // (0) kPriorityLow
390         "normal", // (1) kPriorityNormal
391         "high",   // (2) kPriorityHigh
392         "net",    // (3) kPriorityNet
393     };
394 
395     static_assert(kPriorityLow == 0, "kPriorityLow value is incorrect");
396     static_assert(kPriorityNormal == 1, "kPriorityNormal value is incorrect");
397     static_assert(kPriorityHigh == 2, "kPriorityHigh value is incorrect");
398     static_assert(kPriorityNet == 3, "kPriorityNet value is incorrect");
399 
400     return kPriorityStrings[aPriority];
401 }
402 
AppendBytes(const void * aBuf,uint16_t aLength)403 Error Message::AppendBytes(const void *aBuf, uint16_t aLength)
404 {
405     Error    error     = kErrorNone;
406     uint16_t oldLength = GetLength();
407 
408     SuccessOrExit(error = SetLength(GetLength() + aLength));
409     WriteBytes(oldLength, aBuf, aLength);
410 
411 exit:
412     return error;
413 }
414 
AppendBytesFromMessage(const Message & aMessage,uint16_t aOffset,uint16_t aLength)415 Error Message::AppendBytesFromMessage(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
416 {
417     Error    error       = kErrorNone;
418     uint16_t writeOffset = GetLength();
419     Chunk    chunk;
420 
421     VerifyOrExit(aMessage.GetLength() >= aOffset + aLength, error = kErrorParse);
422     SuccessOrExit(error = SetLength(GetLength() + aLength));
423 
424     aMessage.GetFirstChunk(aOffset, aLength, chunk);
425 
426     while (chunk.GetLength() > 0)
427     {
428         WriteBytes(writeOffset, chunk.GetBytes(), chunk.GetLength());
429         writeOffset += chunk.GetLength();
430         aMessage.GetNextChunk(aLength, chunk);
431     }
432 
433 exit:
434     return error;
435 }
436 
PrependBytes(const void * aBuf,uint16_t aLength)437 Error Message::PrependBytes(const void *aBuf, uint16_t aLength)
438 {
439     Error   error     = kErrorNone;
440     Buffer *newBuffer = nullptr;
441 
442     while (aLength > GetReserved())
443     {
444         VerifyOrExit((newBuffer = GetMessagePool()->NewBuffer(GetPriority())) != nullptr, error = kErrorNoBufs);
445 
446         newBuffer->SetNextBuffer(GetNextBuffer());
447         SetNextBuffer(newBuffer);
448 
449         if (GetReserved() < sizeof(mBuffer.mHead.mData))
450         {
451             // Copy payload from the first buffer.
452             memcpy(newBuffer->mBuffer.mHead.mData + GetReserved(), mBuffer.mHead.mData + GetReserved(),
453                    sizeof(mBuffer.mHead.mData) - GetReserved());
454         }
455 
456         SetReserved(GetReserved() + kBufferDataSize);
457     }
458 
459     SetReserved(GetReserved() - aLength);
460     GetMetadata().mLength += aLength;
461     SetOffset(GetOffset() + aLength);
462 
463     if (aBuf != nullptr)
464     {
465         WriteBytes(0, aBuf, aLength);
466     }
467 
468 exit:
469     return error;
470 }
471 
RemoveHeader(uint16_t aLength)472 void Message::RemoveHeader(uint16_t aLength)
473 {
474     OT_ASSERT(aLength <= GetMetadata().mLength);
475 
476     GetMetadata().mReserved += aLength;
477     GetMetadata().mLength -= aLength;
478 
479     if (GetMetadata().mOffset > aLength)
480     {
481         GetMetadata().mOffset -= aLength;
482     }
483     else
484     {
485         GetMetadata().mOffset = 0;
486     }
487 }
488 
RemoveHeader(uint16_t aOffset,uint16_t aLength)489 void Message::RemoveHeader(uint16_t aOffset, uint16_t aLength)
490 {
491     // To shrink the header, we copy the header byte before `aOffset`
492     // forward. Starting at offset `aLength`, we write bytes we read
493     // from offset `0` onward and copy a total of `aOffset` bytes.
494     // Then remove the first `aLength` bytes from message.
495     //
496     //
497     // 0                   aOffset  aOffset + aLength
498     // +-----------------------+---------+------------------------+
499     // | / / / / / / / / / / / | x x x x |                        |
500     // +-----------------------+---------+------------------------+
501     //
502     // 0       aLength                aOffset + aLength
503     // +---------+-----------------------+------------------------+
504     // |         | / / / / / / / / / / / |                        |
505     // +---------+-----------------------+------------------------+
506     //
507     //  0                    aOffset
508     //  +-----------------------+------------------------+
509     //  | / / / / / / / / / / / |                        |
510     //  +-----------------------+------------------------+
511     //
512 
513     WriteBytesFromMessage(/* aWriteOffset */ aLength, *this, /* aReadOffset */ 0, /* aLength */ aOffset);
514     RemoveHeader(aLength);
515 }
516 
InsertHeader(uint16_t aOffset,uint16_t aLength)517 Error Message::InsertHeader(uint16_t aOffset, uint16_t aLength)
518 {
519     Error error;
520 
521     // To make space in header at `aOffset`, we first prepend
522     // `aLength` bytes at front. Then copy the existing bytes
523     // backwards. Starting at offset `0`, we write bytes we read
524     // from offset `aLength` onward and copy a total of `aOffset`
525     // bytes.
526     //
527     // 0                    aOffset
528     // +-----------------------+------------------------+
529     // | / / / / / / / / / / / |                        |
530     // +-----------------------+------------------------+
531     //
532     // 0       aLength                aOffset + aLength
533     // +---------+-----------------------+------------------------+
534     // |         | / / / / / / / / / / / |                        |
535     // +---------+-----------------------+------------------------+
536     //
537     // 0                   aOffset  aOffset + aLength
538     // +-----------------------+---------+------------------------+
539     // | / / / / / / / / / / / |  N E W  |                        |
540     // +-----------------------+---------+------------------------+
541     //
542 
543     SuccessOrExit(error = PrependBytes(nullptr, aLength));
544     WriteBytesFromMessage(/* aWriteOffset */ 0, *this, /* aReadOffset */ aLength, /* aLength */ aOffset);
545 
546 exit:
547     return error;
548 }
549 
RemoveFooter(uint16_t aLength)550 void Message::RemoveFooter(uint16_t aLength) { IgnoreError(SetLength(GetLength() - Min(aLength, GetLength()))); }
551 
void Message::GetFirstChunk(uint16_t aOffset, uint16_t &aLength, Chunk &aChunk) const
{
    // This method gets the first message chunk (contiguous data
    // buffer) corresponding to a given offset and length. On exit
    // `aChunk` is updated such that `aChunk.GetBytes()` gives the
    // pointer to the start of chunk and `aChunk.GetLength()` gives
    // its length. The `aLength` is also decreased by the chunk
    // length.

    // Offset at or past the end: return an empty chunk. `aLength` is left
    // unchanged on this path (the chunk length subtracted below is zero).
    VerifyOrExit(aOffset < GetLength(), aChunk.SetLength(0));

    // Clamp the request so it does not run past the end of the payload.
    if (aOffset + aLength >= GetLength())
    {
        aLength = GetLength() - aOffset;
    }

    // Convert the payload offset into a position in the buffer chain, where the
    // reserved header bytes come first. NOTE: this 16-bit addition relies on
    // reserved + length fitting in `uint16_t` (as `SetLength()` enforces).
    aOffset += GetReserved();

    aChunk.SetBuffer(this);

    // Special case for the first buffer

    if (aOffset < kHeadBufferDataSize)
    {
        aChunk.Init(GetFirstData() + aOffset, kHeadBufferDataSize - aOffset);
        ExitNow();
    }

    aOffset -= kHeadBufferDataSize;

    // Find the `Buffer` matching the offset

    while (true)
    {
        aChunk.SetBuffer(aChunk.GetBuffer()->GetNextBuffer());

        // The clamp above keeps the position inside the chain, so a successor must exist.
        OT_ASSERT(aChunk.GetBuffer() != nullptr);

        if (aOffset < kBufferDataSize)
        {
            aChunk.Init(aChunk.GetBuffer()->GetData() + aOffset, kBufferDataSize - aOffset);
            ExitNow();
        }

        aOffset -= kBufferDataSize;
    }

exit:
    // Trim the chunk to what was requested and consume it from `aLength`.
    if (aChunk.GetLength() > aLength)
    {
        aChunk.SetLength(aLength);
    }

    aLength -= aChunk.GetLength();
}
607 
GetNextChunk(uint16_t & aLength,Chunk & aChunk) const608 void Message::GetNextChunk(uint16_t &aLength, Chunk &aChunk) const
609 {
610     // This method gets the next message chunk. On input, the
611     // `aChunk` should be the previous chunk. On exit, it is
612     // updated to provide info about next chunk, and `aLength`
613     // is decreased by the chunk length. If there is no more
614     // chunk, `aChunk.GetLength()` would be zero.
615 
616     VerifyOrExit(aLength > 0, aChunk.SetLength(0));
617 
618     aChunk.SetBuffer(aChunk.GetBuffer()->GetNextBuffer());
619 
620     OT_ASSERT(aChunk.GetBuffer() != nullptr);
621 
622     aChunk.Init(aChunk.GetBuffer()->GetData(), kBufferDataSize);
623 
624     if (aChunk.GetLength() > aLength)
625     {
626         aChunk.SetLength(aLength);
627     }
628 
629     aLength -= aChunk.GetLength();
630 
631 exit:
632     return;
633 }
634 
ReadBytes(uint16_t aOffset,void * aBuf,uint16_t aLength) const635 uint16_t Message::ReadBytes(uint16_t aOffset, void *aBuf, uint16_t aLength) const
636 {
637     uint8_t *bufPtr = reinterpret_cast<uint8_t *>(aBuf);
638     Chunk    chunk;
639 
640     GetFirstChunk(aOffset, aLength, chunk);
641 
642     while (chunk.GetLength() > 0)
643     {
644         chunk.CopyBytesTo(bufPtr);
645         bufPtr += chunk.GetLength();
646         GetNextChunk(aLength, chunk);
647     }
648 
649     return static_cast<uint16_t>(bufPtr - reinterpret_cast<uint8_t *>(aBuf));
650 }
651 
Read(uint16_t aOffset,void * aBuf,uint16_t aLength) const652 Error Message::Read(uint16_t aOffset, void *aBuf, uint16_t aLength) const
653 {
654     return (ReadBytes(aOffset, aBuf, aLength) == aLength) ? kErrorNone : kErrorParse;
655 }
656 
CompareBytes(uint16_t aOffset,const void * aBuf,uint16_t aLength,ByteMatcher aMatcher) const657 bool Message::CompareBytes(uint16_t aOffset, const void *aBuf, uint16_t aLength, ByteMatcher aMatcher) const
658 {
659     uint16_t       bytesToCompare = aLength;
660     const uint8_t *bufPtr         = reinterpret_cast<const uint8_t *>(aBuf);
661     Chunk          chunk;
662 
663     GetFirstChunk(aOffset, aLength, chunk);
664 
665     while (chunk.GetLength() > 0)
666     {
667         VerifyOrExit(chunk.MatchesBytesIn(bufPtr, aMatcher));
668         bufPtr += chunk.GetLength();
669         bytesToCompare -= chunk.GetLength();
670         GetNextChunk(aLength, chunk);
671     }
672 
673 exit:
674     return (bytesToCompare == 0);
675 }
676 
CompareBytes(uint16_t aOffset,const Message & aOtherMessage,uint16_t aOtherOffset,uint16_t aLength,ByteMatcher aMatcher) const677 bool Message::CompareBytes(uint16_t       aOffset,
678                            const Message &aOtherMessage,
679                            uint16_t       aOtherOffset,
680                            uint16_t       aLength,
681                            ByteMatcher    aMatcher) const
682 {
683     uint16_t bytesToCompare = aLength;
684     Chunk    chunk;
685 
686     GetFirstChunk(aOffset, aLength, chunk);
687 
688     while (chunk.GetLength() > 0)
689     {
690         VerifyOrExit(aOtherMessage.CompareBytes(aOtherOffset, chunk.GetBytes(), chunk.GetLength(), aMatcher));
691         aOtherOffset += chunk.GetLength();
692         bytesToCompare -= chunk.GetLength();
693         GetNextChunk(aLength, chunk);
694     }
695 
696 exit:
697     return (bytesToCompare == 0);
698 }
699 
WriteBytes(uint16_t aOffset,const void * aBuf,uint16_t aLength)700 void Message::WriteBytes(uint16_t aOffset, const void *aBuf, uint16_t aLength)
701 {
702     const uint8_t *bufPtr = reinterpret_cast<const uint8_t *>(aBuf);
703     MutableChunk   chunk;
704 
705     OT_ASSERT(aOffset + aLength <= GetLength());
706 
707     GetFirstChunk(aOffset, aLength, chunk);
708 
709     while (chunk.GetLength() > 0)
710     {
711         memmove(chunk.GetBytes(), bufPtr, chunk.GetLength());
712         bufPtr += chunk.GetLength();
713         GetNextChunk(aLength, chunk);
714     }
715 }
716 
WriteBytesFromMessage(uint16_t aWriteOffset,const Message & aMessage,uint16_t aReadOffset,uint16_t aLength)717 void Message::WriteBytesFromMessage(uint16_t       aWriteOffset,
718                                     const Message &aMessage,
719                                     uint16_t       aReadOffset,
720                                     uint16_t       aLength)
721 {
722     if ((&aMessage != this) || (aReadOffset >= aWriteOffset))
723     {
724         Chunk chunk;
725 
726         aMessage.GetFirstChunk(aReadOffset, aLength, chunk);
727 
728         while (chunk.GetLength() > 0)
729         {
730             WriteBytes(aWriteOffset, chunk.GetBytes(), chunk.GetLength());
731             aWriteOffset += chunk.GetLength();
732             aMessage.GetNextChunk(aLength, chunk);
733         }
734     }
735     else
736     {
737         // We are copying bytes within the same message forward.
738         // To ensure copy forward works, we read and write from
739         // end of range and move backwards.
740 
741         static constexpr uint16_t kBufSize = 32;
742 
743         uint8_t buf[kBufSize];
744 
745         aWriteOffset += aLength;
746         aReadOffset += aLength;
747 
748         while (aLength > 0)
749         {
750             uint16_t copyLength = Min(kBufSize, aLength);
751 
752             aLength -= copyLength;
753             aReadOffset -= copyLength;
754             aWriteOffset -= copyLength;
755 
756             ReadBytes(aReadOffset, buf, copyLength);
757             WriteBytes(aWriteOffset, buf, copyLength);
758         }
759     }
760 }
761 
Clone(uint16_t aLength) const762 Message *Message::Clone(uint16_t aLength) const
763 {
764     Error    error = kErrorNone;
765     Message *messageCopy;
766     Settings settings(IsLinkSecurityEnabled() ? kWithLinkSecurity : kNoLinkSecurity, GetPriority());
767     uint16_t offset;
768 
769     aLength     = Min(GetLength(), aLength);
770     messageCopy = GetMessagePool()->Allocate(GetType(), GetReserved(), settings);
771     VerifyOrExit(messageCopy != nullptr, error = kErrorNoBufs);
772     SuccessOrExit(error = messageCopy->AppendBytesFromMessage(*this, 0, aLength));
773 
774     // Copy selected message information.
775 
776     offset = Min(GetOffset(), aLength);
777     messageCopy->SetOffset(offset);
778 
779     messageCopy->SetSubType(GetSubType());
780     messageCopy->SetLoopbackToHostAllowed(IsLoopbackToHostAllowed());
781     messageCopy->SetOrigin(GetOrigin());
782     messageCopy->SetTimestamp(GetTimestamp());
783     messageCopy->SetMeshDest(GetMeshDest());
784     messageCopy->SetPanId(GetPanId());
785     messageCopy->SetChannel(GetChannel());
786     messageCopy->SetRssAverager(GetRssAverager());
787     messageCopy->SetLqiAverager(GetLqiAverager());
788 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
789     messageCopy->SetTimeSync(IsTimeSync());
790 #endif
791 
792 exit:
793     FreeAndNullMessageOnError(messageCopy, error);
794     return messageCopy;
795 }
796 
#if OPENTHREAD_FTD
// Per-child delivery bookkeeping stored in the message metadata bit mask.

bool Message::GetChildMask(uint16_t aChildIndex) const
{
    return GetMetadata().mChildMask.Get(aChildIndex);
}

void Message::ClearChildMask(uint16_t aChildIndex)
{
    GetMetadata().mChildMask.Set(aChildIndex, /* aValue */ false);
}

void Message::SetChildMask(uint16_t aChildIndex)
{
    GetMetadata().mChildMask.Set(aChildIndex, /* aValue */ true);
}

// True if any child's bit is still set.
bool Message::IsChildPending(void) const
{
    return GetMetadata().mChildMask.HasAny();
}
#endif
806 
GetLinkInfo(ThreadLinkInfo & aLinkInfo) const807 Error Message::GetLinkInfo(ThreadLinkInfo &aLinkInfo) const
808 {
809     Error error = kErrorNone;
810 
811     VerifyOrExit(IsOriginThreadNetif(), error = kErrorNotFound);
812 
813     aLinkInfo.Clear();
814 
815     aLinkInfo.mPanId               = GetPanId();
816     aLinkInfo.mChannel             = GetChannel();
817     aLinkInfo.mRss                 = GetAverageRss();
818     aLinkInfo.mLqi                 = GetAverageLqi();
819     aLinkInfo.mLinkSecurity        = IsLinkSecurityEnabled();
820     aLinkInfo.mIsDstPanIdBroadcast = IsDstPanIdBroadcast();
821 
822 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
823     aLinkInfo.mTimeSyncSeq       = GetTimeSyncSeq();
824     aLinkInfo.mNetworkTimeOffset = GetNetworkTimeOffset();
825 #endif
826 
827 #if OPENTHREAD_CONFIG_MULTI_RADIO
828     aLinkInfo.mRadioType = GetRadioType();
829 #endif
830 
831 exit:
832     return error;
833 }
834 
UpdateLinkInfoFrom(const ThreadLinkInfo & aLinkInfo)835 void Message::UpdateLinkInfoFrom(const ThreadLinkInfo &aLinkInfo)
836 {
837     SetPanId(aLinkInfo.mPanId);
838     SetChannel(aLinkInfo.mChannel);
839     AddRss(aLinkInfo.mRss);
840     AddLqi(aLinkInfo.mLqi);
841     SetLinkSecurityEnabled(aLinkInfo.mLinkSecurity);
842     GetMetadata().mIsDstPanIdBroadcast = aLinkInfo.IsDstPanIdBroadcast();
843 
844 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
845     SetTimeSyncSeq(aLinkInfo.mTimeSyncSeq);
846     SetNetworkTimeOffset(aLinkInfo.mNetworkTimeOffset);
847 #endif
848 
849 #if OPENTHREAD_CONFIG_MULTI_RADIO
850     SetRadioType(static_cast<Mac::RadioType>(aLinkInfo.mRadioType));
851 #endif
852 }
853 
IsTimeSync(void) const854 bool Message::IsTimeSync(void) const
855 {
856 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
857     return GetMetadata().mTimeSync;
858 #else
859     return false;
860 #endif
861 }
862 
SetMessageQueue(MessageQueue * aMessageQueue)863 void Message::SetMessageQueue(MessageQueue *aMessageQueue)
864 {
865     GetMetadata().mQueue       = aMessageQueue;
866     GetMetadata().mInPriorityQ = false;
867 }
868 
SetPriorityQueue(PriorityQueue * aPriorityQueue)869 void Message::SetPriorityQueue(PriorityQueue *aPriorityQueue)
870 {
871     GetMetadata().mQueue       = aPriorityQueue;
872     GetMetadata().mInPriorityQ = true;
873 }
874 
875 //---------------------------------------------------------------------------------------------------------------------
876 // MessageQueue
877 
Enqueue(Message & aMessage,QueuePosition aPosition)878 void MessageQueue::Enqueue(Message &aMessage, QueuePosition aPosition)
879 {
880     OT_ASSERT(!aMessage.IsInAQueue());
881     OT_ASSERT((aMessage.Next() == nullptr) && (aMessage.Prev() == nullptr));
882 
883     aMessage.SetMessageQueue(this);
884 
885     if (GetTail() == nullptr)
886     {
887         aMessage.Next() = &aMessage;
888         aMessage.Prev() = &aMessage;
889 
890         SetTail(&aMessage);
891     }
892     else
893     {
894         Message *head = GetTail()->Next();
895 
896         aMessage.Next() = head;
897         aMessage.Prev() = GetTail();
898 
899         head->Prev()      = &aMessage;
900         GetTail()->Next() = &aMessage;
901 
902         if (aPosition == kQueuePositionTail)
903         {
904             SetTail(&aMessage);
905         }
906     }
907 }
908 
Dequeue(Message & aMessage)909 void MessageQueue::Dequeue(Message &aMessage)
910 {
911     OT_ASSERT(aMessage.GetMessageQueue() == this);
912     OT_ASSERT((aMessage.Next() != nullptr) && (aMessage.Prev() != nullptr));
913 
914     if (&aMessage == GetTail())
915     {
916         SetTail(GetTail()->Prev());
917 
918         if (&aMessage == GetTail())
919         {
920             SetTail(nullptr);
921         }
922     }
923 
924     aMessage.Prev()->Next() = aMessage.Next();
925     aMessage.Next()->Prev() = aMessage.Prev();
926 
927     aMessage.Prev() = nullptr;
928     aMessage.Next() = nullptr;
929 
930     aMessage.SetMessageQueue(nullptr);
931 }
932 
DequeueAndFree(Message & aMessage)933 void MessageQueue::DequeueAndFree(Message &aMessage)
934 {
935     Dequeue(aMessage);
936     aMessage.Free();
937 }
938 
DequeueAndFreeAll(void)939 void MessageQueue::DequeueAndFreeAll(void)
940 {
941     Message *message;
942 
943     while ((message = GetHead()) != nullptr)
944     {
945         DequeueAndFree(*message);
946     }
947 }
948 
begin(void)949 Message::Iterator MessageQueue::begin(void) { return Message::Iterator(GetHead()); }
950 
begin(void) const951 Message::ConstIterator MessageQueue::begin(void) const { return Message::ConstIterator(GetHead()); }
952 
GetInfo(Info & aInfo) const953 void MessageQueue::GetInfo(Info &aInfo) const
954 {
955     for (const Message &message : *this)
956     {
957         aInfo.mNumMessages++;
958         aInfo.mNumBuffers += message.GetBufferCount();
959         aInfo.mTotalBytes += message.GetLength();
960     }
961 }
962 
963 //---------------------------------------------------------------------------------------------------------------------
964 // PriorityQueue
965 
FindFirstNonNullTail(Message::Priority aStartPriorityLevel) const966 const Message *PriorityQueue::FindFirstNonNullTail(Message::Priority aStartPriorityLevel) const
967 {
968     // Find the first non-`nullptr` tail starting from the given priority
969     // level and moving forward (wrapping from priority value
970     // `kNumPriorities` -1 back to 0).
971 
972     const Message *tail = nullptr;
973     uint8_t        priority;
974 
975     priority = static_cast<uint8_t>(aStartPriorityLevel);
976 
977     do
978     {
979         if (mTails[priority] != nullptr)
980         {
981             tail = mTails[priority];
982             break;
983         }
984 
985         priority = PrevPriority(priority);
986     } while (priority != aStartPriorityLevel);
987 
988     return tail;
989 }
990 
GetHead(void) const991 const Message *PriorityQueue::GetHead(void) const
992 {
993     return Message::NextOf(FindFirstNonNullTail(Message::kPriorityLow));
994 }
995 
GetHeadForPriority(Message::Priority aPriority) const996 const Message *PriorityQueue::GetHeadForPriority(Message::Priority aPriority) const
997 {
998     const Message *head;
999     const Message *previousTail;
1000 
1001     if (mTails[aPriority] != nullptr)
1002     {
1003         previousTail = FindFirstNonNullTail(static_cast<Message::Priority>(PrevPriority(aPriority)));
1004 
1005         OT_ASSERT(previousTail != nullptr);
1006 
1007         head = previousTail->Next();
1008     }
1009     else
1010     {
1011         head = nullptr;
1012     }
1013 
1014     return head;
1015 }
1016 
GetTail(void) const1017 const Message *PriorityQueue::GetTail(void) const { return FindFirstNonNullTail(Message::kPriorityLow); }
1018 
Enqueue(Message & aMessage)1019 void PriorityQueue::Enqueue(Message &aMessage)
1020 {
1021     Message::Priority priority;
1022     Message          *tail;
1023     Message          *next;
1024 
1025     OT_ASSERT(!aMessage.IsInAQueue());
1026 
1027     aMessage.SetPriorityQueue(this);
1028 
1029     priority = aMessage.GetPriority();
1030 
1031     tail = FindFirstNonNullTail(priority);
1032 
1033     if (tail != nullptr)
1034     {
1035         next = tail->Next();
1036 
1037         aMessage.Next() = next;
1038         aMessage.Prev() = tail;
1039         next->Prev()    = &aMessage;
1040         tail->Next()    = &aMessage;
1041     }
1042     else
1043     {
1044         aMessage.Next() = &aMessage;
1045         aMessage.Prev() = &aMessage;
1046     }
1047 
1048     mTails[priority] = &aMessage;
1049 }
1050 
Dequeue(Message & aMessage)1051 void PriorityQueue::Dequeue(Message &aMessage)
1052 {
1053     Message::Priority priority;
1054     Message          *tail;
1055 
1056     OT_ASSERT(aMessage.GetPriorityQueue() == this);
1057 
1058     priority = aMessage.GetPriority();
1059 
1060     tail = mTails[priority];
1061 
1062     if (&aMessage == tail)
1063     {
1064         tail = tail->Prev();
1065 
1066         if ((&aMessage == tail) || (tail->GetPriority() != priority))
1067         {
1068             tail = nullptr;
1069         }
1070 
1071         mTails[priority] = tail;
1072     }
1073 
1074     aMessage.Next()->Prev() = aMessage.Prev();
1075     aMessage.Prev()->Next() = aMessage.Next();
1076     aMessage.Next()         = nullptr;
1077     aMessage.Prev()         = nullptr;
1078 
1079     aMessage.SetPriorityQueue(nullptr);
1080 }
1081 
DequeueAndFree(Message & aMessage)1082 void PriorityQueue::DequeueAndFree(Message &aMessage)
1083 {
1084     Dequeue(aMessage);
1085     aMessage.Free();
1086 }
1087 
DequeueAndFreeAll(void)1088 void PriorityQueue::DequeueAndFreeAll(void)
1089 {
1090     Message *message;
1091 
1092     while ((message = GetHead()) != nullptr)
1093     {
1094         DequeueAndFree(*message);
1095     }
1096 }
1097 
begin(void)1098 Message::Iterator PriorityQueue::begin(void) { return Message::Iterator(GetHead()); }
1099 
begin(void) const1100 Message::ConstIterator PriorityQueue::begin(void) const { return Message::ConstIterator(GetHead()); }
1101 
GetInfo(Info & aInfo) const1102 void PriorityQueue::GetInfo(Info &aInfo) const
1103 {
1104     for (const Message &message : *this)
1105     {
1106         aInfo.mNumMessages++;
1107         aInfo.mNumBuffers += message.GetBufferCount();
1108         aInfo.mTotalBytes += message.GetLength();
1109     }
1110 }
1111 
1112 } // namespace ot
1113 #endif // OPENTHREAD_MTD || OPENTHREAD_FTD
1114