1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the message buffer pool and message buffers.
32  */
33 
34 #include "message.hpp"
35 
36 #include "common/as_core_type.hpp"
37 #include "common/code_utils.hpp"
38 #include "common/debug.hpp"
39 #include "common/heap.hpp"
40 #include "common/locator_getters.hpp"
41 #include "common/log.hpp"
42 #include "common/num_utils.hpp"
43 #include "common/numeric_limits.hpp"
44 #include "instance/instance.hpp"
45 #include "net/checksum.hpp"
46 #include "net/ip6.hpp"
47 
48 #if OPENTHREAD_MTD || OPENTHREAD_FTD
49 
50 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE && OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
51 #error "OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE conflicts with OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT."
52 #endif
53 
54 namespace ot {
55 
56 RegisterLogModule("Message");
57 
58 //---------------------------------------------------------------------------------------------------------------------
59 // MessagePool
60 
MessagePool::MessagePool(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mNumAllocated(0)
    , mMaxAllocated(0)
{
    // The pool starts with no buffers in use: both the current
    // allocation count and its high-water mark are zero.

#if OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
    // In this configuration the buffers are provided by the platform,
    // so tell it how many buffers are needed and the size of each one.
    otPlatMessagePoolInit(&GetInstance(), kNumBuffers, sizeof(Buffer));
#endif
}
70 
Allocate(Message::Type aType,uint16_t aReserveHeader,const Message::Settings & aSettings)71 Message *MessagePool::Allocate(Message::Type aType, uint16_t aReserveHeader, const Message::Settings &aSettings)
72 {
73     Error    error = kErrorNone;
74     Message *message;
75 
76     VerifyOrExit((message = static_cast<Message *>(NewBuffer(aSettings.GetPriority()))) != nullptr);
77 
78     memset(message, 0, sizeof(*message));
79     message->SetMessagePool(this);
80     message->SetType(aType);
81     message->SetReserved(aReserveHeader);
82     message->SetLinkSecurityEnabled(aSettings.IsLinkSecurityEnabled());
83     message->SetLoopbackToHostAllowed(OPENTHREAD_CONFIG_IP6_ALLOW_LOOP_BACK_HOST_DATAGRAMS);
84     message->SetOrigin(Message::kOriginHostTrusted);
85 
86     SuccessOrExit(error = message->SetPriority(aSettings.GetPriority()));
87     SuccessOrExit(error = message->SetLength(0));
88 
89 exit:
90     if (error != kErrorNone)
91     {
92         Free(message);
93         message = nullptr;
94     }
95 
96     return message;
97 }
98 
Allocate(Message::Type aType)99 Message *MessagePool::Allocate(Message::Type aType) { return Allocate(aType, 0, Message::Settings::GetDefault()); }
100 
Allocate(Message::Type aType,uint16_t aReserveHeader)101 Message *MessagePool::Allocate(Message::Type aType, uint16_t aReserveHeader)
102 {
103     return Allocate(aType, aReserveHeader, Message::Settings::GetDefault());
104 }
105 
Free(Message * aMessage)106 void MessagePool::Free(Message *aMessage)
107 {
108     OT_ASSERT(aMessage->Next() == nullptr && aMessage->Prev() == nullptr);
109 
110     FreeBuffers(static_cast<Buffer *>(aMessage));
111 }
112 
NewBuffer(Message::Priority aPriority)113 Buffer *MessagePool::NewBuffer(Message::Priority aPriority)
114 {
115     Buffer *buffer = nullptr;
116 
117     while ((
118 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
119                buffer = static_cast<Buffer *>(Heap::CAlloc(1, sizeof(Buffer)))
120 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
121                buffer = static_cast<Buffer *>(otPlatMessagePoolNew(&GetInstance()))
122 #else
123                buffer = mBufferPool.Allocate()
124 #endif
125                    ) == nullptr)
126     {
127         SuccessOrExit(ReclaimBuffers(aPriority));
128     }
129 
130     mNumAllocated++;
131     mMaxAllocated = Max(mMaxAllocated, mNumAllocated);
132 
133     buffer->SetNextBuffer(nullptr);
134 
135 exit:
136     if (buffer == nullptr)
137     {
138         LogInfo("No available message buffer");
139     }
140 
141     return buffer;
142 }
143 
FreeBuffers(Buffer * aBuffer)144 void MessagePool::FreeBuffers(Buffer *aBuffer)
145 {
146     while (aBuffer != nullptr)
147     {
148         Buffer *next = aBuffer->GetNextBuffer();
149 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
150         Heap::Free(aBuffer);
151 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
152         otPlatMessagePoolFree(&GetInstance(), aBuffer);
153 #else
154         mBufferPool.Free(*aBuffer);
155 #endif
156         mNumAllocated--;
157 
158         aBuffer = next;
159     }
160 }
161 
ReclaimBuffers(Message::Priority aPriority)162 Error MessagePool::ReclaimBuffers(Message::Priority aPriority) { return Get<MeshForwarder>().EvictMessage(aPriority); }
163 
GetFreeBufferCount(void) const164 uint16_t MessagePool::GetFreeBufferCount(void) const
165 {
166     uint16_t rval;
167 
168 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
169 #if !OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE
170     rval = static_cast<uint16_t>(Instance::GetHeap().GetFreeSize() / sizeof(Buffer));
171 #else
172     rval = NumericLimits<uint16_t>::kMax;
173 #endif
174 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
175     rval = otPlatMessagePoolNumFreeBuffers(&GetInstance());
176 #else
177     rval = kNumBuffers - mNumAllocated;
178 #endif
179 
180     return rval;
181 }
182 
GetTotalBufferCount(void) const183 uint16_t MessagePool::GetTotalBufferCount(void) const
184 {
185     uint16_t rval;
186 
187 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
188 #if !OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE
189     rval = static_cast<uint16_t>(Instance::GetHeap().GetCapacity() / sizeof(Buffer));
190 #else
191     rval = NumericLimits<uint16_t>::kMax;
192 #endif
193 #else
194     rval = OPENTHREAD_CONFIG_NUM_MESSAGE_BUFFERS;
195 #endif
196 
197     return rval;
198 }
199 
200 //---------------------------------------------------------------------------------------------------------------------
201 // Message::Settings
202 
// Default message settings: link security enabled, normal priority.
const otMessageSettings Message::Settings::kDefault = {kWithLinkSecurity, kPriorityNormal};
204 
Settings(LinkSecurityMode aSecurityMode,Priority aPriority)205 Message::Settings::Settings(LinkSecurityMode aSecurityMode, Priority aPriority)
206 {
207     mLinkSecurityEnabled = aSecurityMode;
208     mPriority            = aPriority;
209 }
210 
From(const otMessageSettings * aSettings)211 const Message::Settings &Message::Settings::From(const otMessageSettings *aSettings)
212 {
213     return (aSettings == nullptr) ? GetDefault() : AsCoreType(aSettings);
214 }
215 
216 //---------------------------------------------------------------------------------------------------------------------
217 // Message::Iterator
218 
void Message::Iterator::Advance(void)
{
    // Step to the message that was looked up ahead of time, then look
    // up the one after it.
    // NOTE(review): keeping `mNext` one step ahead presumably lets the
    // current message be removed during iteration - confirm with the
    // iterator's users.

    mItem = mNext;
    mNext = NextMessage(mNext);
}
224 
225 //---------------------------------------------------------------------------------------------------------------------
226 // Message
227 
ResizeMessage(uint16_t aLength)228 Error Message::ResizeMessage(uint16_t aLength)
229 {
230     // This method adds or frees message buffers to meet the
231     // requested length.
232 
233     Error    error     = kErrorNone;
234     Buffer  *curBuffer = this;
235     Buffer  *lastBuffer;
236     uint16_t curLength = kHeadBufferDataSize;
237 
238     while (curLength < aLength)
239     {
240         if (curBuffer->GetNextBuffer() == nullptr)
241         {
242             curBuffer->SetNextBuffer(GetMessagePool()->NewBuffer(GetPriority()));
243             VerifyOrExit(curBuffer->GetNextBuffer() != nullptr, error = kErrorNoBufs);
244         }
245 
246         curBuffer = curBuffer->GetNextBuffer();
247         curLength += kBufferDataSize;
248     }
249 
250     lastBuffer = curBuffer;
251     curBuffer  = curBuffer->GetNextBuffer();
252     lastBuffer->SetNextBuffer(nullptr);
253 
254     GetMessagePool()->FreeBuffers(curBuffer);
255 
256 exit:
257     return error;
258 }
259 
Free(void)260 void Message::Free(void) { GetMessagePool()->Free(this); }
261 
GetNext(void) const262 Message *Message::GetNext(void) const
263 {
264     Message *next;
265     Message *tail;
266 
267     if (GetMetadata().mInPriorityQ)
268     {
269         PriorityQueue *priorityQueue = GetPriorityQueue();
270         VerifyOrExit(priorityQueue != nullptr, next = nullptr);
271         tail = priorityQueue->GetTail();
272     }
273     else
274     {
275         MessageQueue *messageQueue = GetMessageQueue();
276         VerifyOrExit(messageQueue != nullptr, next = nullptr);
277         tail = messageQueue->GetTail();
278     }
279 
280     next = (this == tail) ? nullptr : Next();
281 
282 exit:
283     return next;
284 }
285 
SetLength(uint16_t aLength)286 Error Message::SetLength(uint16_t aLength)
287 {
288     Error    error              = kErrorNone;
289     uint16_t totalLengthRequest = GetReserved() + aLength;
290 
291     VerifyOrExit(totalLengthRequest >= GetReserved(), error = kErrorInvalidArgs);
292 
293     SuccessOrExit(error = ResizeMessage(totalLengthRequest));
294     GetMetadata().mLength = aLength;
295 
296     // Correct the offset in case shorter length is set.
297     if (GetOffset() > aLength)
298     {
299         SetOffset(aLength);
300     }
301 
302 exit:
303     return error;
304 }
305 
GetBufferCount(void) const306 uint8_t Message::GetBufferCount(void) const
307 {
308     uint8_t rval = 1;
309 
310     for (const Buffer *curBuffer = GetNextBuffer(); curBuffer; curBuffer = curBuffer->GetNextBuffer())
311     {
312         rval++;
313     }
314 
315     return rval;
316 }
317 
void Message::MoveOffset(int aDelta)
{
    // Moves the message offset by `aDelta` bytes (which may be
    // negative). The resulting offset must not exceed the message
    // length; this is asserted both before and after the update.
    // NOTE(review): the second assert presumably also catches a
    // negative delta that would move the offset below zero (assuming
    // `mOffset` is unsigned and wraps) - confirm against `Metadata`.

    OT_ASSERT(GetOffset() + aDelta <= GetLength());
    GetMetadata().mOffset += static_cast<int16_t>(aDelta);
    OT_ASSERT(GetMetadata().mOffset <= GetLength());
}
324 
void Message::SetOffset(uint16_t aOffset)
{
    // Sets the message offset to `aOffset`, which must not be beyond
    // the current message length (asserted).

    OT_ASSERT(aOffset <= GetLength());
    GetMetadata().mOffset = aOffset;
}
330 
IsSubTypeMle(void) const331 bool Message::IsSubTypeMle(void) const
332 {
333     bool rval;
334 
335     switch (GetMetadata().mSubType)
336     {
337     case kSubTypeMleGeneral:
338     case kSubTypeMleAnnounce:
339     case kSubTypeMleDiscoverRequest:
340     case kSubTypeMleDiscoverResponse:
341     case kSubTypeMleChildUpdateRequest:
342     case kSubTypeMleDataResponse:
343     case kSubTypeMleChildIdRequest:
344         rval = true;
345         break;
346 
347     default:
348         rval = false;
349         break;
350     }
351 
352     return rval;
353 }
354 
SetPriority(Priority aPriority)355 Error Message::SetPriority(Priority aPriority)
356 {
357     Error          error    = kErrorNone;
358     uint8_t        priority = static_cast<uint8_t>(aPriority);
359     PriorityQueue *priorityQueue;
360 
361     static_assert(kNumPriorities <= 4, "`Metadata::mPriority` as a 2-bit field cannot fit all `Priority` values");
362 
363     VerifyOrExit(priority < kNumPriorities, error = kErrorInvalidArgs);
364 
365     VerifyOrExit(IsInAQueue(), GetMetadata().mPriority = priority);
366     VerifyOrExit(GetMetadata().mPriority != priority);
367 
368     priorityQueue = GetPriorityQueue();
369 
370     if (priorityQueue != nullptr)
371     {
372         priorityQueue->Dequeue(*this);
373     }
374 
375     GetMetadata().mPriority = priority;
376 
377     if (priorityQueue != nullptr)
378     {
379         priorityQueue->Enqueue(*this);
380     }
381 
382 exit:
383     return error;
384 }
385 
PriorityToString(Priority aPriority)386 const char *Message::PriorityToString(Priority aPriority)
387 {
388     static const char *const kPriorityStrings[] = {
389         "low",    // (0) kPriorityLow
390         "normal", // (1) kPriorityNormal
391         "high",   // (2) kPriorityHigh
392         "net",    // (3) kPriorityNet
393     };
394 
395     static_assert(kPriorityLow == 0, "kPriorityLow value is incorrect");
396     static_assert(kPriorityNormal == 1, "kPriorityNormal value is incorrect");
397     static_assert(kPriorityHigh == 2, "kPriorityHigh value is incorrect");
398     static_assert(kPriorityNet == 3, "kPriorityNet value is incorrect");
399 
400     return kPriorityStrings[aPriority];
401 }
402 
AppendBytes(const void * aBuf,uint16_t aLength)403 Error Message::AppendBytes(const void *aBuf, uint16_t aLength)
404 {
405     Error    error     = kErrorNone;
406     uint16_t oldLength = GetLength();
407 
408     SuccessOrExit(error = SetLength(GetLength() + aLength));
409     WriteBytes(oldLength, aBuf, aLength);
410 
411 exit:
412     return error;
413 }
414 
AppendBytesFromMessage(const Message & aMessage,uint16_t aOffset,uint16_t aLength)415 Error Message::AppendBytesFromMessage(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
416 {
417     Error    error       = kErrorNone;
418     uint16_t writeOffset = GetLength();
419     Chunk    chunk;
420 
421     VerifyOrExit(aMessage.GetLength() >= aOffset + aLength, error = kErrorParse);
422     SuccessOrExit(error = SetLength(GetLength() + aLength));
423 
424     aMessage.GetFirstChunk(aOffset, aLength, chunk);
425 
426     while (chunk.GetLength() > 0)
427     {
428         WriteBytes(writeOffset, chunk.GetBytes(), chunk.GetLength());
429         writeOffset += chunk.GetLength();
430         aMessage.GetNextChunk(aLength, chunk);
431     }
432 
433 exit:
434     return error;
435 }
436 
Error Message::PrependBytes(const void *aBuf, uint16_t aLength)
{
    // Adds `aLength` bytes at the front of the message, taking them
    // from the reserved header space. When `aBuf` is not null, its
    // content is written into the new bytes; otherwise the new bytes
    // are left as they are for the caller to fill in.
    //
    // Returns `kErrorNoBufs` if the reserved space has to be enlarged
    // and a buffer cannot be allocated. Buffers inserted before the
    // failure remain part of the message (with the reserved space
    // enlarged accordingly).

    Error   error     = kErrorNone;
    Buffer *newBuffer = nullptr;

    // Enlarge the reserved space, one buffer at a time, until it can
    // hold the bytes being prepended.
    while (aLength > GetReserved())
    {
        VerifyOrExit((newBuffer = GetMessagePool()->NewBuffer(GetPriority())) != nullptr, error = kErrorNoBufs);

        // Insert the new buffer directly after the head buffer.
        newBuffer->SetNextBuffer(GetNextBuffer());
        SetNextBuffer(newBuffer);

        if (GetReserved() < sizeof(mBuffer.mHead.mData))
        {
            // Copy payload from the first buffer.
            // NOTE(review): this relies on the head data area lining
            // up with the end of a regular buffer's data area, so that
            // the copied bytes land where the enlarged reserved space
            // expects them - confirm against the `Buffer` layout.
            memcpy(newBuffer->mBuffer.mHead.mData + GetReserved(), mBuffer.mHead.mData + GetReserved(),
                   sizeof(mBuffer.mHead.mData) - GetReserved());
        }

        SetReserved(GetReserved() + kBufferDataSize);
    }

    // Turn `aLength` bytes of reserved space into message content and
    // shift the offset so it keeps pointing at the same byte.
    SetReserved(GetReserved() - aLength);
    GetMetadata().mLength += aLength;
    SetOffset(GetOffset() + aLength);

    if (aBuf != nullptr)
    {
        WriteBytes(0, aBuf, aLength);
    }

exit:
    return error;
}
471 
RemoveHeader(uint16_t aLength)472 void Message::RemoveHeader(uint16_t aLength)
473 {
474     OT_ASSERT(aLength <= GetMetadata().mLength);
475 
476     GetMetadata().mReserved += aLength;
477     GetMetadata().mLength -= aLength;
478 
479     if (GetMetadata().mOffset > aLength)
480     {
481         GetMetadata().mOffset -= aLength;
482     }
483     else
484     {
485         GetMetadata().mOffset = 0;
486     }
487 }
488 
void Message::RemoveHeader(uint16_t aOffset, uint16_t aLength)
{
    // Removes `aLength` bytes located at `aOffset` from the message.
    //
    // To shrink the header, we copy the header bytes before `aOffset`
    // forward. Starting at offset `aLength`, we write bytes we read
    // from offset `0` onward and copy a total of `aOffset` bytes.
    // Then we remove the first `aLength` bytes from the message.
    //
    //
    // 0                   aOffset  aOffset + aLength
    // +-----------------------+---------+------------------------+
    // | / / / / / / / / / / / | x x x x |                        |
    // +-----------------------+---------+------------------------+
    //
    // 0       aLength                aOffset + aLength
    // +---------+-----------------------+------------------------+
    // |         | / / / / / / / / / / / |                        |
    // +---------+-----------------------+------------------------+
    //
    //  0                    aOffset
    //  +-----------------------+------------------------+
    //  | / / / / / / / / / / / |                        |
    //  +-----------------------+------------------------+
    //

    WriteBytesFromMessage(/* aWriteOffset */ aLength, *this, /* aReadOffset */ 0, /* aLength */ aOffset);
    RemoveHeader(aLength);
}
516 
Error Message::InsertHeader(uint16_t aOffset, uint16_t aLength)
{
    // Opens a gap of `aLength` bytes at `aOffset` in the message. The
    // content of the gap is not initialized by this method. Returns
    // the error from `PrependBytes()` if the message cannot be grown.

    Error error;

    // To make space in header at `aOffset`, we first prepend
    // `aLength` bytes at front. Then copy the existing bytes
    // backwards. Starting at offset `0`, we write bytes we read
    // from offset `aLength` onward and copy a total of `aOffset`
    // bytes.
    //
    // 0                    aOffset
    // +-----------------------+------------------------+
    // | / / / / / / / / / / / |                        |
    // +-----------------------+------------------------+
    //
    // 0       aLength                aOffset + aLength
    // +---------+-----------------------+------------------------+
    // |         | / / / / / / / / / / / |                        |
    // +---------+-----------------------+------------------------+
    //
    // 0                   aOffset  aOffset + aLength
    // +-----------------------+---------+------------------------+
    // | / / / / / / / / / / / |  N E W  |                        |
    // +-----------------------+---------+------------------------+
    //

    SuccessOrExit(error = PrependBytes(nullptr, aLength));
    WriteBytesFromMessage(/* aWriteOffset */ 0, *this, /* aReadOffset */ aLength, /* aLength */ aOffset);

exit:
    return error;
}
549 
void Message::GetFirstChunk(uint16_t aOffset, uint16_t &aLength, Chunk &aChunk) const
{
    // This method gets the first message chunk (contiguous data
    // buffer) corresponding to a given offset and length. On exit
    // `aChunk` is updated such that `aChunk.GetBytes()` gives the
    // pointer to the start of chunk and `aChunk.GetLength()` gives
    // its length. The `aLength` is also decreased by the chunk
    // length.
    //
    // If `aOffset` is at or beyond the end of the message, the chunk
    // length is set to zero and `aLength` is left unchanged.

    VerifyOrExit(aOffset < GetLength(), aChunk.SetLength(0));

    // Clamp the requested range to the end of the message.
    if (aOffset + aLength >= GetLength())
    {
        aLength = GetLength() - aOffset;
    }

    // From here on `aOffset` is counted from the start of the head
    // buffer's data area, i.e. it includes the reserved space.
    aOffset += GetReserved();

    aChunk.SetBuffer(this);

    // Special case for the first buffer

    if (aOffset < kHeadBufferDataSize)
    {
        aChunk.Init(GetFirstData() + aOffset, kHeadBufferDataSize - aOffset);
        ExitNow();
    }

    aOffset -= kHeadBufferDataSize;

    // Find the `Buffer` matching the offset

    while (true)
    {
        aChunk.SetBuffer(aChunk.GetBuffer()->GetNextBuffer());

        OT_ASSERT(aChunk.GetBuffer() != nullptr);

        if (aOffset < kBufferDataSize)
        {
            aChunk.Init(aChunk.GetBuffer()->GetData() + aOffset, kBufferDataSize - aOffset);
            ExitNow();
        }

        aOffset -= kBufferDataSize;
    }

exit:
    // The chunk initially runs to the end of its buffer; trim it so it
    // does not go past the requested range.
    if (aChunk.GetLength() > aLength)
    {
        aChunk.SetLength(aLength);
    }

    aLength -= aChunk.GetLength();
}
605 
GetNextChunk(uint16_t & aLength,Chunk & aChunk) const606 void Message::GetNextChunk(uint16_t &aLength, Chunk &aChunk) const
607 {
608     // This method gets the next message chunk. On input, the
609     // `aChunk` should be the previous chunk. On exit, it is
610     // updated to provide info about next chunk, and `aLength`
611     // is decreased by the chunk length. If there is no more
612     // chunk, `aChunk.GetLength()` would be zero.
613 
614     VerifyOrExit(aLength > 0, aChunk.SetLength(0));
615 
616     aChunk.SetBuffer(aChunk.GetBuffer()->GetNextBuffer());
617 
618     OT_ASSERT(aChunk.GetBuffer() != nullptr);
619 
620     aChunk.Init(aChunk.GetBuffer()->GetData(), kBufferDataSize);
621 
622     if (aChunk.GetLength() > aLength)
623     {
624         aChunk.SetLength(aLength);
625     }
626 
627     aLength -= aChunk.GetLength();
628 
629 exit:
630     return;
631 }
632 
ReadBytes(uint16_t aOffset,void * aBuf,uint16_t aLength) const633 uint16_t Message::ReadBytes(uint16_t aOffset, void *aBuf, uint16_t aLength) const
634 {
635     uint8_t *bufPtr = reinterpret_cast<uint8_t *>(aBuf);
636     Chunk    chunk;
637 
638     GetFirstChunk(aOffset, aLength, chunk);
639 
640     while (chunk.GetLength() > 0)
641     {
642         chunk.CopyBytesTo(bufPtr);
643         bufPtr += chunk.GetLength();
644         GetNextChunk(aLength, chunk);
645     }
646 
647     return static_cast<uint16_t>(bufPtr - reinterpret_cast<uint8_t *>(aBuf));
648 }
649 
Read(uint16_t aOffset,void * aBuf,uint16_t aLength) const650 Error Message::Read(uint16_t aOffset, void *aBuf, uint16_t aLength) const
651 {
652     return (ReadBytes(aOffset, aBuf, aLength) == aLength) ? kErrorNone : kErrorParse;
653 }
654 
CompareBytes(uint16_t aOffset,const void * aBuf,uint16_t aLength,ByteMatcher aMatcher) const655 bool Message::CompareBytes(uint16_t aOffset, const void *aBuf, uint16_t aLength, ByteMatcher aMatcher) const
656 {
657     uint16_t       bytesToCompare = aLength;
658     const uint8_t *bufPtr         = reinterpret_cast<const uint8_t *>(aBuf);
659     Chunk          chunk;
660 
661     GetFirstChunk(aOffset, aLength, chunk);
662 
663     while (chunk.GetLength() > 0)
664     {
665         VerifyOrExit(chunk.MatchesBytesIn(bufPtr, aMatcher));
666         bufPtr += chunk.GetLength();
667         bytesToCompare -= chunk.GetLength();
668         GetNextChunk(aLength, chunk);
669     }
670 
671 exit:
672     return (bytesToCompare == 0);
673 }
674 
CompareBytes(uint16_t aOffset,const Message & aOtherMessage,uint16_t aOtherOffset,uint16_t aLength,ByteMatcher aMatcher) const675 bool Message::CompareBytes(uint16_t       aOffset,
676                            const Message &aOtherMessage,
677                            uint16_t       aOtherOffset,
678                            uint16_t       aLength,
679                            ByteMatcher    aMatcher) const
680 {
681     uint16_t bytesToCompare = aLength;
682     Chunk    chunk;
683 
684     GetFirstChunk(aOffset, aLength, chunk);
685 
686     while (chunk.GetLength() > 0)
687     {
688         VerifyOrExit(aOtherMessage.CompareBytes(aOtherOffset, chunk.GetBytes(), chunk.GetLength(), aMatcher));
689         aOtherOffset += chunk.GetLength();
690         bytesToCompare -= chunk.GetLength();
691         GetNextChunk(aLength, chunk);
692     }
693 
694 exit:
695     return (bytesToCompare == 0);
696 }
697 
WriteBytes(uint16_t aOffset,const void * aBuf,uint16_t aLength)698 void Message::WriteBytes(uint16_t aOffset, const void *aBuf, uint16_t aLength)
699 {
700     const uint8_t *bufPtr = reinterpret_cast<const uint8_t *>(aBuf);
701     MutableChunk   chunk;
702 
703     OT_ASSERT(aOffset + aLength <= GetLength());
704 
705     GetFirstChunk(aOffset, aLength, chunk);
706 
707     while (chunk.GetLength() > 0)
708     {
709         memmove(chunk.GetBytes(), bufPtr, chunk.GetLength());
710         bufPtr += chunk.GetLength();
711         GetNextChunk(aLength, chunk);
712     }
713 }
714 
WriteBytesFromMessage(uint16_t aWriteOffset,const Message & aMessage,uint16_t aReadOffset,uint16_t aLength)715 void Message::WriteBytesFromMessage(uint16_t       aWriteOffset,
716                                     const Message &aMessage,
717                                     uint16_t       aReadOffset,
718                                     uint16_t       aLength)
719 {
720     if ((&aMessage != this) || (aReadOffset >= aWriteOffset))
721     {
722         Chunk chunk;
723 
724         aMessage.GetFirstChunk(aReadOffset, aLength, chunk);
725 
726         while (chunk.GetLength() > 0)
727         {
728             WriteBytes(aWriteOffset, chunk.GetBytes(), chunk.GetLength());
729             aWriteOffset += chunk.GetLength();
730             aMessage.GetNextChunk(aLength, chunk);
731         }
732     }
733     else
734     {
735         // We are copying bytes within the same message forward.
736         // To ensure copy forward works, we read and write from
737         // end of range and move backwards.
738 
739         static constexpr uint16_t kBufSize = 32;
740 
741         uint8_t buf[kBufSize];
742 
743         aWriteOffset += aLength;
744         aReadOffset += aLength;
745 
746         while (aLength > 0)
747         {
748             uint16_t copyLength = Min(kBufSize, aLength);
749 
750             aLength -= copyLength;
751             aReadOffset -= copyLength;
752             aWriteOffset -= copyLength;
753 
754             ReadBytes(aReadOffset, buf, copyLength);
755             WriteBytes(aWriteOffset, buf, copyLength);
756         }
757     }
758 }
759 
Clone(uint16_t aLength) const760 Message *Message::Clone(uint16_t aLength) const
761 {
762     Error    error = kErrorNone;
763     Message *messageCopy;
764     Settings settings(IsLinkSecurityEnabled() ? kWithLinkSecurity : kNoLinkSecurity, GetPriority());
765     uint16_t offset;
766 
767     aLength     = Min(GetLength(), aLength);
768     messageCopy = GetMessagePool()->Allocate(GetType(), GetReserved(), settings);
769     VerifyOrExit(messageCopy != nullptr, error = kErrorNoBufs);
770     SuccessOrExit(error = messageCopy->AppendBytesFromMessage(*this, 0, aLength));
771 
772     // Copy selected message information.
773     offset = Min(GetOffset(), aLength);
774     messageCopy->SetOffset(offset);
775 
776     messageCopy->SetSubType(GetSubType());
777     messageCopy->SetLoopbackToHostAllowed(IsLoopbackToHostAllowed());
778     messageCopy->SetOrigin(GetOrigin());
779 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
780     messageCopy->SetTimeSync(IsTimeSync());
781 #endif
782 
783 exit:
784     FreeAndNullMessageOnError(messageCopy, error);
785     return messageCopy;
786 }
787 
788 #if OPENTHREAD_FTD
GetChildMask(uint16_t aChildIndex) const789 bool Message::GetChildMask(uint16_t aChildIndex) const { return GetMetadata().mChildMask.Get(aChildIndex); }
790 
ClearChildMask(uint16_t aChildIndex)791 void Message::ClearChildMask(uint16_t aChildIndex) { GetMetadata().mChildMask.Set(aChildIndex, false); }
792 
SetChildMask(uint16_t aChildIndex)793 void Message::SetChildMask(uint16_t aChildIndex) { GetMetadata().mChildMask.Set(aChildIndex, true); }
794 
IsChildPending(void) const795 bool Message::IsChildPending(void) const { return GetMetadata().mChildMask.HasAny(); }
796 #endif
797 
void Message::SetLinkInfo(const ThreadLinkInfo &aLinkInfo)
{
    // Copies link-layer information from `aLinkInfo` into the message:
    // link security flag, PAN ID and RSS, plus the optional fields
    // enabled by the build configuration below.

    SetLinkSecurityEnabled(aLinkInfo.mLinkSecurity);
    SetPanId(aLinkInfo.mPanId);
    AddRss(aLinkInfo.mRss);
#if OPENTHREAD_CONFIG_MLE_LINK_METRICS_SUBJECT_ENABLE
    AddLqi(aLinkInfo.mLqi);
#endif
#if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
    SetTimeSyncSeq(aLinkInfo.mTimeSyncSeq);
    SetNetworkTimeOffset(aLinkInfo.mNetworkTimeOffset);
#endif
#if OPENTHREAD_CONFIG_MULTI_RADIO
    SetRadioType(static_cast<Mac::RadioType>(aLinkInfo.mRadioType));
#endif
}
814 
IsTimeSync(void) const815 bool Message::IsTimeSync(void) const
816 {
817 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
818     return GetMetadata().mTimeSync;
819 #else
820     return false;
821 #endif
822 }
823 
SetMessageQueue(MessageQueue * aMessageQueue)824 void Message::SetMessageQueue(MessageQueue *aMessageQueue)
825 {
826     GetMetadata().mQueue       = aMessageQueue;
827     GetMetadata().mInPriorityQ = false;
828 }
829 
SetPriorityQueue(PriorityQueue * aPriorityQueue)830 void Message::SetPriorityQueue(PriorityQueue *aPriorityQueue)
831 {
832     GetMetadata().mQueue       = aPriorityQueue;
833     GetMetadata().mInPriorityQ = true;
834 }
835 
836 //---------------------------------------------------------------------------------------------------------------------
837 // MessageQueue
838 
Enqueue(Message & aMessage,QueuePosition aPosition)839 void MessageQueue::Enqueue(Message &aMessage, QueuePosition aPosition)
840 {
841     OT_ASSERT(!aMessage.IsInAQueue());
842     OT_ASSERT((aMessage.Next() == nullptr) && (aMessage.Prev() == nullptr));
843 
844     aMessage.SetMessageQueue(this);
845 
846     if (GetTail() == nullptr)
847     {
848         aMessage.Next() = &aMessage;
849         aMessage.Prev() = &aMessage;
850 
851         SetTail(&aMessage);
852     }
853     else
854     {
855         Message *head = GetTail()->Next();
856 
857         aMessage.Next() = head;
858         aMessage.Prev() = GetTail();
859 
860         head->Prev()      = &aMessage;
861         GetTail()->Next() = &aMessage;
862 
863         if (aPosition == kQueuePositionTail)
864         {
865             SetTail(&aMessage);
866         }
867     }
868 }
869 
Dequeue(Message & aMessage)870 void MessageQueue::Dequeue(Message &aMessage)
871 {
872     OT_ASSERT(aMessage.GetMessageQueue() == this);
873     OT_ASSERT((aMessage.Next() != nullptr) && (aMessage.Prev() != nullptr));
874 
875     if (&aMessage == GetTail())
876     {
877         SetTail(GetTail()->Prev());
878 
879         if (&aMessage == GetTail())
880         {
881             SetTail(nullptr);
882         }
883     }
884 
885     aMessage.Prev()->Next() = aMessage.Next();
886     aMessage.Next()->Prev() = aMessage.Prev();
887 
888     aMessage.Prev() = nullptr;
889     aMessage.Next() = nullptr;
890 
891     aMessage.SetMessageQueue(nullptr);
892 }
893 
void MessageQueue::DequeueAndFree(Message &aMessage)
{
    // Removes `aMessage` from this queue and then frees it. The
    // message must be dequeued first because freeing asserts that the
    // message is no longer linked.

    Dequeue(aMessage);
    aMessage.Free();
}
899 
DequeueAndFreeAll(void)900 void MessageQueue::DequeueAndFreeAll(void)
901 {
902     Message *message;
903 
904     while ((message = GetHead()) != nullptr)
905     {
906         DequeueAndFree(*message);
907     }
908 }
909 
begin(void)910 Message::Iterator MessageQueue::begin(void) { return Message::Iterator(GetHead()); }
911 
begin(void) const912 Message::ConstIterator MessageQueue::begin(void) const { return Message::ConstIterator(GetHead()); }
913 
GetInfo(Info & aInfo) const914 void MessageQueue::GetInfo(Info &aInfo) const
915 {
916     for (const Message &message : *this)
917     {
918         aInfo.mNumMessages++;
919         aInfo.mNumBuffers += message.GetBufferCount();
920         aInfo.mTotalBytes += message.GetLength();
921     }
922 }
923 
924 //---------------------------------------------------------------------------------------------------------------------
925 // PriorityQueue
926 
FindFirstNonNullTail(Message::Priority aStartPriorityLevel) const927 const Message *PriorityQueue::FindFirstNonNullTail(Message::Priority aStartPriorityLevel) const
928 {
929     // Find the first non-`nullptr` tail starting from the given priority
930     // level and moving forward (wrapping from priority value
931     // `kNumPriorities` -1 back to 0).
932 
933     const Message *tail = nullptr;
934     uint8_t        priority;
935 
936     priority = static_cast<uint8_t>(aStartPriorityLevel);
937 
938     do
939     {
940         if (mTails[priority] != nullptr)
941         {
942             tail = mTails[priority];
943             break;
944         }
945 
946         priority = PrevPriority(priority);
947     } while (priority != aStartPriorityLevel);
948 
949     return tail;
950 }
951 
GetHead(void) const952 const Message *PriorityQueue::GetHead(void) const
953 {
954     return Message::NextOf(FindFirstNonNullTail(Message::kPriorityLow));
955 }
956 
GetHeadForPriority(Message::Priority aPriority) const957 const Message *PriorityQueue::GetHeadForPriority(Message::Priority aPriority) const
958 {
959     const Message *head;
960     const Message *previousTail;
961 
962     if (mTails[aPriority] != nullptr)
963     {
964         previousTail = FindFirstNonNullTail(static_cast<Message::Priority>(PrevPriority(aPriority)));
965 
966         OT_ASSERT(previousTail != nullptr);
967 
968         head = previousTail->Next();
969     }
970     else
971     {
972         head = nullptr;
973     }
974 
975     return head;
976 }
977 
GetTail(void) const978 const Message *PriorityQueue::GetTail(void) const { return FindFirstNonNullTail(Message::kPriorityLow); }
979 
Enqueue(Message & aMessage)980 void PriorityQueue::Enqueue(Message &aMessage)
981 {
982     Message::Priority priority;
983     Message          *tail;
984     Message          *next;
985 
986     OT_ASSERT(!aMessage.IsInAQueue());
987 
988     aMessage.SetPriorityQueue(this);
989 
990     priority = aMessage.GetPriority();
991 
992     tail = FindFirstNonNullTail(priority);
993 
994     if (tail != nullptr)
995     {
996         next = tail->Next();
997 
998         aMessage.Next() = next;
999         aMessage.Prev() = tail;
1000         next->Prev()    = &aMessage;
1001         tail->Next()    = &aMessage;
1002     }
1003     else
1004     {
1005         aMessage.Next() = &aMessage;
1006         aMessage.Prev() = &aMessage;
1007     }
1008 
1009     mTails[priority] = &aMessage;
1010 }
1011 
Dequeue(Message & aMessage)1012 void PriorityQueue::Dequeue(Message &aMessage)
1013 {
1014     Message::Priority priority;
1015     Message          *tail;
1016 
1017     OT_ASSERT(aMessage.GetPriorityQueue() == this);
1018 
1019     priority = aMessage.GetPriority();
1020 
1021     tail = mTails[priority];
1022 
1023     if (&aMessage == tail)
1024     {
1025         tail = tail->Prev();
1026 
1027         if ((&aMessage == tail) || (tail->GetPriority() != priority))
1028         {
1029             tail = nullptr;
1030         }
1031 
1032         mTails[priority] = tail;
1033     }
1034 
1035     aMessage.Next()->Prev() = aMessage.Prev();
1036     aMessage.Prev()->Next() = aMessage.Next();
1037     aMessage.Next()         = nullptr;
1038     aMessage.Prev()         = nullptr;
1039 
1040     aMessage.SetPriorityQueue(nullptr);
1041 }
1042 
DequeueAndFree(Message & aMessage)1043 void PriorityQueue::DequeueAndFree(Message &aMessage)
1044 {
1045     Dequeue(aMessage);
1046     aMessage.Free();
1047 }
1048 
DequeueAndFreeAll(void)1049 void PriorityQueue::DequeueAndFreeAll(void)
1050 {
1051     Message *message;
1052 
1053     while ((message = GetHead()) != nullptr)
1054     {
1055         DequeueAndFree(*message);
1056     }
1057 }
1058 
begin(void)1059 Message::Iterator PriorityQueue::begin(void) { return Message::Iterator(GetHead()); }
1060 
begin(void) const1061 Message::ConstIterator PriorityQueue::begin(void) const { return Message::ConstIterator(GetHead()); }
1062 
GetInfo(Info & aInfo) const1063 void PriorityQueue::GetInfo(Info &aInfo) const
1064 {
1065     for (const Message &message : *this)
1066     {
1067         aInfo.mNumMessages++;
1068         aInfo.mNumBuffers += message.GetBufferCount();
1069         aInfo.mTotalBytes += message.GetLength();
1070     }
1071 }
1072 
1073 } // namespace ot
1074 #endif // OPENTHREAD_MTD || OPENTHREAD_FTD
1075