1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the message buffer pool and message buffers.
32  */
33 
34 #include "message.hpp"
35 
36 #include "common/as_core_type.hpp"
37 #include "common/code_utils.hpp"
38 #include "common/debug.hpp"
39 #include "common/heap.hpp"
40 #include "common/instance.hpp"
41 #include "common/locator_getters.hpp"
42 #include "common/log.hpp"
43 #include "common/num_utils.hpp"
44 #include "common/numeric_limits.hpp"
45 #include "net/checksum.hpp"
46 #include "net/ip6.hpp"
47 
48 #if OPENTHREAD_MTD || OPENTHREAD_FTD
49 
50 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE && OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
51 #error "OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE conflicts with OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT."
52 #endif
53 
54 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE && !OPENTHREAD_CONFIG_DTLS_ENABLE
55 #error "OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE is strongly discouraged when OPENTHREAD_CONFIG_DTLS_ENABLE is off."
56 #endif
57 
58 namespace ot {
59 
60 RegisterLogModule("Message");
61 
62 //---------------------------------------------------------------------------------------------------------------------
63 // MessagePool
64 
MessagePool(Instance & aInstance)65 MessagePool::MessagePool(Instance &aInstance)
66     : InstanceLocator(aInstance)
67     , mNumAllocated(0)
68     , mMaxAllocated(0)
69 {
70 #if OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
71     otPlatMessagePoolInit(&GetInstance(), kNumBuffers, sizeof(Buffer));
72 #endif
73 }
74 
Allocate(Message::Type aType,uint16_t aReserveHeader,const Message::Settings & aSettings)75 Message *MessagePool::Allocate(Message::Type aType, uint16_t aReserveHeader, const Message::Settings &aSettings)
76 {
77     Error    error = kErrorNone;
78     Message *message;
79 
80     VerifyOrExit((message = static_cast<Message *>(NewBuffer(aSettings.GetPriority()))) != nullptr);
81 
82     memset(message, 0, sizeof(*message));
83     message->SetMessagePool(this);
84     message->SetType(aType);
85     message->SetReserved(aReserveHeader);
86     message->SetLinkSecurityEnabled(aSettings.IsLinkSecurityEnabled());
87 
88     SuccessOrExit(error = message->SetPriority(aSettings.GetPriority()));
89     SuccessOrExit(error = message->SetLength(0));
90 
91 exit:
92     if (error != kErrorNone)
93     {
94         Free(message);
95         message = nullptr;
96     }
97 
98     return message;
99 }
100 
Allocate(Message::Type aType)101 Message *MessagePool::Allocate(Message::Type aType) { return Allocate(aType, 0, Message::Settings::GetDefault()); }
102 
Allocate(Message::Type aType,uint16_t aReserveHeader)103 Message *MessagePool::Allocate(Message::Type aType, uint16_t aReserveHeader)
104 {
105     return Allocate(aType, aReserveHeader, Message::Settings::GetDefault());
106 }
107 
Free(Message * aMessage)108 void MessagePool::Free(Message *aMessage)
109 {
110     OT_ASSERT(aMessage->Next() == nullptr && aMessage->Prev() == nullptr);
111 
112     FreeBuffers(static_cast<Buffer *>(aMessage));
113 }
114 
NewBuffer(Message::Priority aPriority)115 Buffer *MessagePool::NewBuffer(Message::Priority aPriority)
116 {
117     Buffer *buffer = nullptr;
118 
119     while ((
120 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
121                buffer = static_cast<Buffer *>(Heap::CAlloc(1, sizeof(Buffer)))
122 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
123                buffer = static_cast<Buffer *>(otPlatMessagePoolNew(&GetInstance()))
124 #else
125                buffer = mBufferPool.Allocate()
126 #endif
127                    ) == nullptr)
128     {
129         SuccessOrExit(ReclaimBuffers(aPriority));
130     }
131 
132     mNumAllocated++;
133     mMaxAllocated = Max(mMaxAllocated, mNumAllocated);
134 
135     buffer->SetNextBuffer(nullptr);
136 
137 exit:
138     if (buffer == nullptr)
139     {
140         LogInfo("No available message buffer");
141     }
142 
143     return buffer;
144 }
145 
FreeBuffers(Buffer * aBuffer)146 void MessagePool::FreeBuffers(Buffer *aBuffer)
147 {
148     while (aBuffer != nullptr)
149     {
150         Buffer *next = aBuffer->GetNextBuffer();
151 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
152         Heap::Free(aBuffer);
153 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
154         otPlatMessagePoolFree(&GetInstance(), aBuffer);
155 #else
156         mBufferPool.Free(*aBuffer);
157 #endif
158         mNumAllocated--;
159 
160         aBuffer = next;
161     }
162 }
163 
ReclaimBuffers(Message::Priority aPriority)164 Error MessagePool::ReclaimBuffers(Message::Priority aPriority) { return Get<MeshForwarder>().EvictMessage(aPriority); }
165 
GetFreeBufferCount(void) const166 uint16_t MessagePool::GetFreeBufferCount(void) const
167 {
168     uint16_t rval;
169 
170 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
171 #if !OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE
172     rval = static_cast<uint16_t>(Instance::GetHeap().GetFreeSize() / sizeof(Buffer));
173 #else
174     rval = NumericLimits<uint16_t>::kMax;
175 #endif
176 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
177     rval = otPlatMessagePoolNumFreeBuffers(&GetInstance());
178 #else
179     rval = kNumBuffers - mNumAllocated;
180 #endif
181 
182     return rval;
183 }
184 
GetTotalBufferCount(void) const185 uint16_t MessagePool::GetTotalBufferCount(void) const
186 {
187     uint16_t rval;
188 
189 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
190 #if !OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE
191     rval = static_cast<uint16_t>(Instance::GetHeap().GetCapacity() / sizeof(Buffer));
192 #else
193     rval = NumericLimits<uint16_t>::kMax;
194 #endif
195 #else
196     rval = OPENTHREAD_CONFIG_NUM_MESSAGE_BUFFERS;
197 #endif
198 
199     return rval;
200 }
201 
202 //---------------------------------------------------------------------------------------------------------------------
203 // Message::Settings
204 
205 const otMessageSettings Message::Settings::kDefault = {kWithLinkSecurity, kPriorityNormal};
206 
Settings(LinkSecurityMode aSecurityMode,Priority aPriority)207 Message::Settings::Settings(LinkSecurityMode aSecurityMode, Priority aPriority)
208 {
209     mLinkSecurityEnabled = aSecurityMode;
210     mPriority            = aPriority;
211 }
212 
From(const otMessageSettings * aSettings)213 const Message::Settings &Message::Settings::From(const otMessageSettings *aSettings)
214 {
215     return (aSettings == nullptr) ? GetDefault() : AsCoreType(aSettings);
216 }
217 
218 //---------------------------------------------------------------------------------------------------------------------
219 // Message::Iterator
220 
Advance(void)221 void Message::Iterator::Advance(void)
222 {
223     mItem = mNext;
224     mNext = NextMessage(mNext);
225 }
226 
227 //---------------------------------------------------------------------------------------------------------------------
228 // Message
229 
ResizeMessage(uint16_t aLength)230 Error Message::ResizeMessage(uint16_t aLength)
231 {
232     // This method adds or frees message buffers to meet the
233     // requested length.
234 
235     Error    error     = kErrorNone;
236     Buffer  *curBuffer = this;
237     Buffer  *lastBuffer;
238     uint16_t curLength = kHeadBufferDataSize;
239 
240     while (curLength < aLength)
241     {
242         if (curBuffer->GetNextBuffer() == nullptr)
243         {
244             curBuffer->SetNextBuffer(GetMessagePool()->NewBuffer(GetPriority()));
245             VerifyOrExit(curBuffer->GetNextBuffer() != nullptr, error = kErrorNoBufs);
246         }
247 
248         curBuffer = curBuffer->GetNextBuffer();
249         curLength += kBufferDataSize;
250     }
251 
252     lastBuffer = curBuffer;
253     curBuffer  = curBuffer->GetNextBuffer();
254     lastBuffer->SetNextBuffer(nullptr);
255 
256     GetMessagePool()->FreeBuffers(curBuffer);
257 
258 exit:
259     return error;
260 }
261 
Free(void)262 void Message::Free(void) { GetMessagePool()->Free(this); }
263 
GetNext(void) const264 Message *Message::GetNext(void) const
265 {
266     Message *next;
267     Message *tail;
268 
269     if (GetMetadata().mInPriorityQ)
270     {
271         PriorityQueue *priorityQueue = GetPriorityQueue();
272         VerifyOrExit(priorityQueue != nullptr, next = nullptr);
273         tail = priorityQueue->GetTail();
274     }
275     else
276     {
277         MessageQueue *messageQueue = GetMessageQueue();
278         VerifyOrExit(messageQueue != nullptr, next = nullptr);
279         tail = messageQueue->GetTail();
280     }
281 
282     next = (this == tail) ? nullptr : Next();
283 
284 exit:
285     return next;
286 }
287 
SetLength(uint16_t aLength)288 Error Message::SetLength(uint16_t aLength)
289 {
290     Error    error              = kErrorNone;
291     uint16_t totalLengthRequest = GetReserved() + aLength;
292 
293     VerifyOrExit(totalLengthRequest >= GetReserved(), error = kErrorInvalidArgs);
294 
295     SuccessOrExit(error = ResizeMessage(totalLengthRequest));
296     GetMetadata().mLength = aLength;
297 
298     // Correct the offset in case shorter length is set.
299     if (GetOffset() > aLength)
300     {
301         SetOffset(aLength);
302     }
303 
304 exit:
305     return error;
306 }
307 
GetBufferCount(void) const308 uint8_t Message::GetBufferCount(void) const
309 {
310     uint8_t rval = 1;
311 
312     for (const Buffer *curBuffer = GetNextBuffer(); curBuffer; curBuffer = curBuffer->GetNextBuffer())
313     {
314         rval++;
315     }
316 
317     return rval;
318 }
319 
MoveOffset(int aDelta)320 void Message::MoveOffset(int aDelta)
321 {
322     OT_ASSERT(GetOffset() + aDelta <= GetLength());
323     GetMetadata().mOffset += static_cast<int16_t>(aDelta);
324     OT_ASSERT(GetMetadata().mOffset <= GetLength());
325 }
326 
SetOffset(uint16_t aOffset)327 void Message::SetOffset(uint16_t aOffset)
328 {
329     OT_ASSERT(aOffset <= GetLength());
330     GetMetadata().mOffset = aOffset;
331 }
332 
IsSubTypeMle(void) const333 bool Message::IsSubTypeMle(void) const
334 {
335     bool rval;
336 
337     switch (GetMetadata().mSubType)
338     {
339     case kSubTypeMleGeneral:
340     case kSubTypeMleAnnounce:
341     case kSubTypeMleDiscoverRequest:
342     case kSubTypeMleDiscoverResponse:
343     case kSubTypeMleChildUpdateRequest:
344     case kSubTypeMleDataResponse:
345     case kSubTypeMleChildIdRequest:
346         rval = true;
347         break;
348 
349     default:
350         rval = false;
351         break;
352     }
353 
354     return rval;
355 }
356 
SetPriority(Priority aPriority)357 Error Message::SetPriority(Priority aPriority)
358 {
359     Error          error    = kErrorNone;
360     uint8_t        priority = static_cast<uint8_t>(aPriority);
361     PriorityQueue *priorityQueue;
362 
363     static_assert(kNumPriorities <= 4, "`Metadata::mPriority` as a 2-bit field cannot fit all `Priority` values");
364 
365     VerifyOrExit(priority < kNumPriorities, error = kErrorInvalidArgs);
366 
367     VerifyOrExit(IsInAQueue(), GetMetadata().mPriority = priority);
368     VerifyOrExit(GetMetadata().mPriority != priority);
369 
370     priorityQueue = GetPriorityQueue();
371 
372     if (priorityQueue != nullptr)
373     {
374         priorityQueue->Dequeue(*this);
375     }
376 
377     GetMetadata().mPriority = priority;
378 
379     if (priorityQueue != nullptr)
380     {
381         priorityQueue->Enqueue(*this);
382     }
383 
384 exit:
385     return error;
386 }
387 
PriorityToString(Priority aPriority)388 const char *Message::PriorityToString(Priority aPriority)
389 {
390     static const char *const kPriorityStrings[] = {
391         "low",    // (0) kPriorityLow
392         "normal", // (1) kPriorityNormal
393         "high",   // (2) kPriorityHigh
394         "net",    // (3) kPriorityNet
395     };
396 
397     static_assert(kPriorityLow == 0, "kPriorityLow value is incorrect");
398     static_assert(kPriorityNormal == 1, "kPriorityNormal value is incorrect");
399     static_assert(kPriorityHigh == 2, "kPriorityHigh value is incorrect");
400     static_assert(kPriorityNet == 3, "kPriorityNet value is incorrect");
401 
402     return kPriorityStrings[aPriority];
403 }
404 
AppendBytes(const void * aBuf,uint16_t aLength)405 Error Message::AppendBytes(const void *aBuf, uint16_t aLength)
406 {
407     Error    error     = kErrorNone;
408     uint16_t oldLength = GetLength();
409 
410     SuccessOrExit(error = SetLength(GetLength() + aLength));
411     WriteBytes(oldLength, aBuf, aLength);
412 
413 exit:
414     return error;
415 }
416 
AppendBytesFromMessage(const Message & aMessage,uint16_t aOffset,uint16_t aLength)417 Error Message::AppendBytesFromMessage(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
418 {
419     Error    error       = kErrorNone;
420     uint16_t writeOffset = GetLength();
421     Chunk    chunk;
422 
423     VerifyOrExit(aMessage.GetLength() >= aOffset + aLength, error = kErrorParse);
424     SuccessOrExit(error = SetLength(GetLength() + aLength));
425 
426     aMessage.GetFirstChunk(aOffset, aLength, chunk);
427 
428     while (chunk.GetLength() > 0)
429     {
430         WriteBytes(writeOffset, chunk.GetBytes(), chunk.GetLength());
431         writeOffset += chunk.GetLength();
432         aMessage.GetNextChunk(aLength, chunk);
433     }
434 
435 exit:
436     return error;
437 }
438 
PrependBytes(const void * aBuf,uint16_t aLength)439 Error Message::PrependBytes(const void *aBuf, uint16_t aLength)
440 {
441     Error   error     = kErrorNone;
442     Buffer *newBuffer = nullptr;
443 
444     while (aLength > GetReserved())
445     {
446         VerifyOrExit((newBuffer = GetMessagePool()->NewBuffer(GetPriority())) != nullptr, error = kErrorNoBufs);
447 
448         newBuffer->SetNextBuffer(GetNextBuffer());
449         SetNextBuffer(newBuffer);
450 
451         if (GetReserved() < sizeof(mBuffer.mHead.mData))
452         {
453             // Copy payload from the first buffer.
454             memcpy(newBuffer->mBuffer.mHead.mData + GetReserved(), mBuffer.mHead.mData + GetReserved(),
455                    sizeof(mBuffer.mHead.mData) - GetReserved());
456         }
457 
458         SetReserved(GetReserved() + kBufferDataSize);
459     }
460 
461     SetReserved(GetReserved() - aLength);
462     GetMetadata().mLength += aLength;
463     SetOffset(GetOffset() + aLength);
464 
465     if (aBuf != nullptr)
466     {
467         WriteBytes(0, aBuf, aLength);
468     }
469 
470 exit:
471     return error;
472 }
473 
RemoveHeader(uint16_t aLength)474 void Message::RemoveHeader(uint16_t aLength)
475 {
476     OT_ASSERT(aLength <= GetMetadata().mLength);
477 
478     GetMetadata().mReserved += aLength;
479     GetMetadata().mLength -= aLength;
480 
481     if (GetMetadata().mOffset > aLength)
482     {
483         GetMetadata().mOffset -= aLength;
484     }
485     else
486     {
487         GetMetadata().mOffset = 0;
488     }
489 }
490 
RemoveHeader(uint16_t aOffset,uint16_t aLength)491 void Message::RemoveHeader(uint16_t aOffset, uint16_t aLength)
492 {
493     // To shrink the header, we copy the header byte before `aOffset`
494     // forward. Starting at offset `aLength`, we write bytes we read
495     // from offset `0` onward and copy a total of `aOffset` bytes.
496     // Then remove the first `aLength` bytes from message.
497     //
498     //
499     // 0                   aOffset  aOffset + aLength
500     // +-----------------------+---------+------------------------+
501     // | / / / / / / / / / / / | x x x x |                        |
502     // +-----------------------+---------+------------------------+
503     //
504     // 0       aLength                aOffset + aLength
505     // +---------+-----------------------+------------------------+
506     // |         | / / / / / / / / / / / |                        |
507     // +---------+-----------------------+------------------------+
508     //
509     //  0                    aOffset
510     //  +-----------------------+------------------------+
511     //  | / / / / / / / / / / / |                        |
512     //  +-----------------------+------------------------+
513     //
514 
515     WriteBytesFromMessage(/* aWriteOffset */ aLength, *this, /* aReadOffset */ 0, /* aLength */ aOffset);
516     RemoveHeader(aLength);
517 }
518 
InsertHeader(uint16_t aOffset,uint16_t aLength)519 Error Message::InsertHeader(uint16_t aOffset, uint16_t aLength)
520 {
521     Error error;
522 
523     // To make space in header at `aOffset`, we first prepend
524     // `aLength` bytes at front. Then copy the existing bytes
525     // backwards. Starting at offset `0`, we write bytes we read
526     // from offset `aLength` onward and copy a total of `aOffset`
527     // bytes.
528     //
529     // 0                    aOffset
530     // +-----------------------+------------------------+
531     // | / / / / / / / / / / / |                        |
532     // +-----------------------+------------------------+
533     //
534     // 0       aLength                aOffset + aLength
535     // +---------+-----------------------+------------------------+
536     // |         | / / / / / / / / / / / |                        |
537     // +---------+-----------------------+------------------------+
538     //
539     // 0                   aOffset  aOffset + aLength
540     // +-----------------------+---------+------------------------+
541     // | / / / / / / / / / / / |  N E W  |                        |
542     // +-----------------------+---------+------------------------+
543     //
544 
545     SuccessOrExit(error = PrependBytes(nullptr, aLength));
546     WriteBytesFromMessage(/* aWriteOffset */ 0, *this, /* aReadOffset */ aLength, /* aLength */ aOffset);
547 
548 exit:
549     return error;
550 }
551 
GetFirstChunk(uint16_t aOffset,uint16_t & aLength,Chunk & aChunk) const552 void Message::GetFirstChunk(uint16_t aOffset, uint16_t &aLength, Chunk &aChunk) const
553 {
554     // This method gets the first message chunk (contiguous data
555     // buffer) corresponding to a given offset and length. On exit
556     // `aChunk` is updated such that `aChunk.GetBytes()` gives the
557     // pointer to the start of chunk and `aChunk.GetLength()` gives
558     // its length. The `aLength` is also decreased by the chunk
559     // length.
560 
561     VerifyOrExit(aOffset < GetLength(), aChunk.SetLength(0));
562 
563     if (aOffset + aLength >= GetLength())
564     {
565         aLength = GetLength() - aOffset;
566     }
567 
568     aOffset += GetReserved();
569 
570     aChunk.SetBuffer(this);
571 
572     // Special case for the first buffer
573 
574     if (aOffset < kHeadBufferDataSize)
575     {
576         aChunk.Init(GetFirstData() + aOffset, kHeadBufferDataSize - aOffset);
577         ExitNow();
578     }
579 
580     aOffset -= kHeadBufferDataSize;
581 
582     // Find the `Buffer` matching the offset
583 
584     while (true)
585     {
586         aChunk.SetBuffer(aChunk.GetBuffer()->GetNextBuffer());
587 
588         OT_ASSERT(aChunk.GetBuffer() != nullptr);
589 
590         if (aOffset < kBufferDataSize)
591         {
592             aChunk.Init(aChunk.GetBuffer()->GetData() + aOffset, kBufferDataSize - aOffset);
593             ExitNow();
594         }
595 
596         aOffset -= kBufferDataSize;
597     }
598 
599 exit:
600     if (aChunk.GetLength() > aLength)
601     {
602         aChunk.SetLength(aLength);
603     }
604 
605     aLength -= aChunk.GetLength();
606 }
607 
GetNextChunk(uint16_t & aLength,Chunk & aChunk) const608 void Message::GetNextChunk(uint16_t &aLength, Chunk &aChunk) const
609 {
610     // This method gets the next message chunk. On input, the
611     // `aChunk` should be the previous chunk. On exit, it is
612     // updated to provide info about next chunk, and `aLength`
613     // is decreased by the chunk length. If there is no more
614     // chunk, `aChunk.GetLength()` would be zero.
615 
616     VerifyOrExit(aLength > 0, aChunk.SetLength(0));
617 
618     aChunk.SetBuffer(aChunk.GetBuffer()->GetNextBuffer());
619 
620     OT_ASSERT(aChunk.GetBuffer() != nullptr);
621 
622     aChunk.Init(aChunk.GetBuffer()->GetData(), kBufferDataSize);
623 
624     if (aChunk.GetLength() > aLength)
625     {
626         aChunk.SetLength(aLength);
627     }
628 
629     aLength -= aChunk.GetLength();
630 
631 exit:
632     return;
633 }
634 
ReadBytes(uint16_t aOffset,void * aBuf,uint16_t aLength) const635 uint16_t Message::ReadBytes(uint16_t aOffset, void *aBuf, uint16_t aLength) const
636 {
637     uint8_t *bufPtr = reinterpret_cast<uint8_t *>(aBuf);
638     Chunk    chunk;
639 
640     GetFirstChunk(aOffset, aLength, chunk);
641 
642     while (chunk.GetLength() > 0)
643     {
644         chunk.CopyBytesTo(bufPtr);
645         bufPtr += chunk.GetLength();
646         GetNextChunk(aLength, chunk);
647     }
648 
649     return static_cast<uint16_t>(bufPtr - reinterpret_cast<uint8_t *>(aBuf));
650 }
651 
Read(uint16_t aOffset,void * aBuf,uint16_t aLength) const652 Error Message::Read(uint16_t aOffset, void *aBuf, uint16_t aLength) const
653 {
654     return (ReadBytes(aOffset, aBuf, aLength) == aLength) ? kErrorNone : kErrorParse;
655 }
656 
CompareBytes(uint16_t aOffset,const void * aBuf,uint16_t aLength,ByteMatcher aMatcher) const657 bool Message::CompareBytes(uint16_t aOffset, const void *aBuf, uint16_t aLength, ByteMatcher aMatcher) const
658 {
659     uint16_t       bytesToCompare = aLength;
660     const uint8_t *bufPtr         = reinterpret_cast<const uint8_t *>(aBuf);
661     Chunk          chunk;
662 
663     GetFirstChunk(aOffset, aLength, chunk);
664 
665     while (chunk.GetLength() > 0)
666     {
667         VerifyOrExit(chunk.MatchesBytesIn(bufPtr, aMatcher));
668         bufPtr += chunk.GetLength();
669         bytesToCompare -= chunk.GetLength();
670         GetNextChunk(aLength, chunk);
671     }
672 
673 exit:
674     return (bytesToCompare == 0);
675 }
676 
CompareBytes(uint16_t aOffset,const Message & aOtherMessage,uint16_t aOtherOffset,uint16_t aLength,ByteMatcher aMatcher) const677 bool Message::CompareBytes(uint16_t       aOffset,
678                            const Message &aOtherMessage,
679                            uint16_t       aOtherOffset,
680                            uint16_t       aLength,
681                            ByteMatcher    aMatcher) const
682 {
683     uint16_t bytesToCompare = aLength;
684     Chunk    chunk;
685 
686     GetFirstChunk(aOffset, aLength, chunk);
687 
688     while (chunk.GetLength() > 0)
689     {
690         VerifyOrExit(aOtherMessage.CompareBytes(aOtherOffset, chunk.GetBytes(), chunk.GetLength(), aMatcher));
691         aOtherOffset += chunk.GetLength();
692         bytesToCompare -= chunk.GetLength();
693         GetNextChunk(aLength, chunk);
694     }
695 
696 exit:
697     return (bytesToCompare == 0);
698 }
699 
WriteBytes(uint16_t aOffset,const void * aBuf,uint16_t aLength)700 void Message::WriteBytes(uint16_t aOffset, const void *aBuf, uint16_t aLength)
701 {
702     const uint8_t *bufPtr = reinterpret_cast<const uint8_t *>(aBuf);
703     MutableChunk   chunk;
704 
705     OT_ASSERT(aOffset + aLength <= GetLength());
706 
707     GetFirstChunk(aOffset, aLength, chunk);
708 
709     while (chunk.GetLength() > 0)
710     {
711         memmove(chunk.GetBytes(), bufPtr, chunk.GetLength());
712         bufPtr += chunk.GetLength();
713         GetNextChunk(aLength, chunk);
714     }
715 }
716 
WriteBytesFromMessage(uint16_t aWriteOffset,const Message & aMessage,uint16_t aReadOffset,uint16_t aLength)717 void Message::WriteBytesFromMessage(uint16_t       aWriteOffset,
718                                     const Message &aMessage,
719                                     uint16_t       aReadOffset,
720                                     uint16_t       aLength)
721 {
722     if ((&aMessage != this) || (aReadOffset >= aWriteOffset))
723     {
724         Chunk chunk;
725 
726         aMessage.GetFirstChunk(aReadOffset, aLength, chunk);
727 
728         while (chunk.GetLength() > 0)
729         {
730             WriteBytes(aWriteOffset, chunk.GetBytes(), chunk.GetLength());
731             aWriteOffset += chunk.GetLength();
732             aMessage.GetNextChunk(aLength, chunk);
733         }
734     }
735     else
736     {
737         // We are copying bytes within the same message forward.
738         // To ensure copy forward works, we read and write from
739         // end of range and move backwards.
740 
741         static constexpr uint16_t kBufSize = 32;
742 
743         uint8_t buf[kBufSize];
744 
745         aWriteOffset += aLength;
746         aReadOffset += aLength;
747 
748         while (aLength > 0)
749         {
750             uint16_t copyLength = Min(kBufSize, aLength);
751 
752             aLength -= copyLength;
753             aReadOffset -= copyLength;
754             aWriteOffset -= copyLength;
755 
756             ReadBytes(aReadOffset, buf, copyLength);
757             WriteBytes(aWriteOffset, buf, copyLength);
758         }
759     }
760 }
761 
Clone(uint16_t aLength) const762 Message *Message::Clone(uint16_t aLength) const
763 {
764     Error    error = kErrorNone;
765     Message *messageCopy;
766     Settings settings(IsLinkSecurityEnabled() ? kWithLinkSecurity : kNoLinkSecurity, GetPriority());
767     uint16_t offset;
768 
769     aLength     = Min(GetLength(), aLength);
770     messageCopy = GetMessagePool()->Allocate(GetType(), GetReserved(), settings);
771     VerifyOrExit(messageCopy != nullptr, error = kErrorNoBufs);
772     SuccessOrExit(error = messageCopy->AppendBytesFromMessage(*this, 0, aLength));
773 
774     // Copy selected message information.
775     offset = Min(GetOffset(), aLength);
776     messageCopy->SetOffset(offset);
777 
778     messageCopy->SetSubType(GetSubType());
779 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
780     messageCopy->SetTimeSync(IsTimeSync());
781 #endif
782 
783 exit:
784     FreeAndNullMessageOnError(messageCopy, error);
785     return messageCopy;
786 }
787 
GetChildMask(uint16_t aChildIndex) const788 bool Message::GetChildMask(uint16_t aChildIndex) const { return GetMetadata().mChildMask.Get(aChildIndex); }
789 
ClearChildMask(uint16_t aChildIndex)790 void Message::ClearChildMask(uint16_t aChildIndex) { GetMetadata().mChildMask.Set(aChildIndex, false); }
791 
SetChildMask(uint16_t aChildIndex)792 void Message::SetChildMask(uint16_t aChildIndex) { GetMetadata().mChildMask.Set(aChildIndex, true); }
793 
IsChildPending(void) const794 bool Message::IsChildPending(void) const { return GetMetadata().mChildMask.HasAny(); }
795 
SetLinkInfo(const ThreadLinkInfo & aLinkInfo)796 void Message::SetLinkInfo(const ThreadLinkInfo &aLinkInfo)
797 {
798     SetLinkSecurityEnabled(aLinkInfo.mLinkSecurity);
799     SetPanId(aLinkInfo.mPanId);
800     AddRss(aLinkInfo.mRss);
801 #if OPENTHREAD_CONFIG_MLE_LINK_METRICS_SUBJECT_ENABLE
802     AddLqi(aLinkInfo.mLqi);
803 #endif
804 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
805     SetTimeSyncSeq(aLinkInfo.mTimeSyncSeq);
806     SetNetworkTimeOffset(aLinkInfo.mNetworkTimeOffset);
807 #endif
808 #if OPENTHREAD_CONFIG_MULTI_RADIO
809     SetRadioType(static_cast<Mac::RadioType>(aLinkInfo.mRadioType));
810 #endif
811 }
812 
IsTimeSync(void) const813 bool Message::IsTimeSync(void) const
814 {
815 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
816     return GetMetadata().mTimeSync;
817 #else
818     return false;
819 #endif
820 }
821 
SetMessageQueue(MessageQueue * aMessageQueue)822 void Message::SetMessageQueue(MessageQueue *aMessageQueue)
823 {
824     GetMetadata().mQueue       = aMessageQueue;
825     GetMetadata().mInPriorityQ = false;
826 }
827 
SetPriorityQueue(PriorityQueue * aPriorityQueue)828 void Message::SetPriorityQueue(PriorityQueue *aPriorityQueue)
829 {
830     GetMetadata().mQueue       = aPriorityQueue;
831     GetMetadata().mInPriorityQ = true;
832 }
833 
834 //---------------------------------------------------------------------------------------------------------------------
835 // MessageQueue
836 
Enqueue(Message & aMessage,QueuePosition aPosition)837 void MessageQueue::Enqueue(Message &aMessage, QueuePosition aPosition)
838 {
839     OT_ASSERT(!aMessage.IsInAQueue());
840     OT_ASSERT((aMessage.Next() == nullptr) && (aMessage.Prev() == nullptr));
841 
842     aMessage.SetMessageQueue(this);
843 
844     if (GetTail() == nullptr)
845     {
846         aMessage.Next() = &aMessage;
847         aMessage.Prev() = &aMessage;
848 
849         SetTail(&aMessage);
850     }
851     else
852     {
853         Message *head = GetTail()->Next();
854 
855         aMessage.Next() = head;
856         aMessage.Prev() = GetTail();
857 
858         head->Prev()      = &aMessage;
859         GetTail()->Next() = &aMessage;
860 
861         if (aPosition == kQueuePositionTail)
862         {
863             SetTail(&aMessage);
864         }
865     }
866 }
867 
Dequeue(Message & aMessage)868 void MessageQueue::Dequeue(Message &aMessage)
869 {
870     OT_ASSERT(aMessage.GetMessageQueue() == this);
871     OT_ASSERT((aMessage.Next() != nullptr) && (aMessage.Prev() != nullptr));
872 
873     if (&aMessage == GetTail())
874     {
875         SetTail(GetTail()->Prev());
876 
877         if (&aMessage == GetTail())
878         {
879             SetTail(nullptr);
880         }
881     }
882 
883     aMessage.Prev()->Next() = aMessage.Next();
884     aMessage.Next()->Prev() = aMessage.Prev();
885 
886     aMessage.Prev() = nullptr;
887     aMessage.Next() = nullptr;
888 
889     aMessage.SetMessageQueue(nullptr);
890 }
891 
DequeueAndFree(Message & aMessage)892 void MessageQueue::DequeueAndFree(Message &aMessage)
893 {
894     Dequeue(aMessage);
895     aMessage.Free();
896 }
897 
DequeueAndFreeAll(void)898 void MessageQueue::DequeueAndFreeAll(void)
899 {
900     Message *message;
901 
902     while ((message = GetHead()) != nullptr)
903     {
904         DequeueAndFree(*message);
905     }
906 }
907 
begin(void)908 Message::Iterator MessageQueue::begin(void) { return Message::Iterator(GetHead()); }
909 
begin(void) const910 Message::ConstIterator MessageQueue::begin(void) const { return Message::ConstIterator(GetHead()); }
911 
GetInfo(Info & aInfo) const912 void MessageQueue::GetInfo(Info &aInfo) const
913 {
914     for (const Message &message : *this)
915     {
916         aInfo.mNumMessages++;
917         aInfo.mNumBuffers += message.GetBufferCount();
918         aInfo.mTotalBytes += message.GetLength();
919     }
920 }
921 
922 //---------------------------------------------------------------------------------------------------------------------
923 // PriorityQueue
924 
FindFirstNonNullTail(Message::Priority aStartPriorityLevel) const925 const Message *PriorityQueue::FindFirstNonNullTail(Message::Priority aStartPriorityLevel) const
926 {
927     // Find the first non-`nullptr` tail starting from the given priority
928     // level and moving forward (wrapping from priority value
929     // `kNumPriorities` -1 back to 0).
930 
931     const Message *tail = nullptr;
932     uint8_t        priority;
933 
934     priority = static_cast<uint8_t>(aStartPriorityLevel);
935 
936     do
937     {
938         if (mTails[priority] != nullptr)
939         {
940             tail = mTails[priority];
941             break;
942         }
943 
944         priority = PrevPriority(priority);
945     } while (priority != aStartPriorityLevel);
946 
947     return tail;
948 }
949 
GetHead(void) const950 const Message *PriorityQueue::GetHead(void) const
951 {
952     return Message::NextOf(FindFirstNonNullTail(Message::kPriorityLow));
953 }
954 
GetHeadForPriority(Message::Priority aPriority) const955 const Message *PriorityQueue::GetHeadForPriority(Message::Priority aPriority) const
956 {
957     const Message *head;
958     const Message *previousTail;
959 
960     if (mTails[aPriority] != nullptr)
961     {
962         previousTail = FindFirstNonNullTail(static_cast<Message::Priority>(PrevPriority(aPriority)));
963 
964         OT_ASSERT(previousTail != nullptr);
965 
966         head = previousTail->Next();
967     }
968     else
969     {
970         head = nullptr;
971     }
972 
973     return head;
974 }
975 
GetTail(void) const976 const Message *PriorityQueue::GetTail(void) const { return FindFirstNonNullTail(Message::kPriorityLow); }
977 
Enqueue(Message & aMessage)978 void PriorityQueue::Enqueue(Message &aMessage)
979 {
980     Message::Priority priority;
981     Message          *tail;
982     Message          *next;
983 
984     OT_ASSERT(!aMessage.IsInAQueue());
985 
986     aMessage.SetPriorityQueue(this);
987 
988     priority = aMessage.GetPriority();
989 
990     tail = FindFirstNonNullTail(priority);
991 
992     if (tail != nullptr)
993     {
994         next = tail->Next();
995 
996         aMessage.Next() = next;
997         aMessage.Prev() = tail;
998         next->Prev()    = &aMessage;
999         tail->Next()    = &aMessage;
1000     }
1001     else
1002     {
1003         aMessage.Next() = &aMessage;
1004         aMessage.Prev() = &aMessage;
1005     }
1006 
1007     mTails[priority] = &aMessage;
1008 }
1009 
Dequeue(Message & aMessage)1010 void PriorityQueue::Dequeue(Message &aMessage)
1011 {
1012     Message::Priority priority;
1013     Message          *tail;
1014 
1015     OT_ASSERT(aMessage.GetPriorityQueue() == this);
1016 
1017     priority = aMessage.GetPriority();
1018 
1019     tail = mTails[priority];
1020 
1021     if (&aMessage == tail)
1022     {
1023         tail = tail->Prev();
1024 
1025         if ((&aMessage == tail) || (tail->GetPriority() != priority))
1026         {
1027             tail = nullptr;
1028         }
1029 
1030         mTails[priority] = tail;
1031     }
1032 
1033     aMessage.Next()->Prev() = aMessage.Prev();
1034     aMessage.Prev()->Next() = aMessage.Next();
1035     aMessage.Next()         = nullptr;
1036     aMessage.Prev()         = nullptr;
1037 
1038     aMessage.SetPriorityQueue(nullptr);
1039 }
1040 
DequeueAndFree(Message & aMessage)1041 void PriorityQueue::DequeueAndFree(Message &aMessage)
1042 {
1043     Dequeue(aMessage);
1044     aMessage.Free();
1045 }
1046 
DequeueAndFreeAll(void)1047 void PriorityQueue::DequeueAndFreeAll(void)
1048 {
1049     Message *message;
1050 
1051     while ((message = GetHead()) != nullptr)
1052     {
1053         DequeueAndFree(*message);
1054     }
1055 }
1056 
begin(void)1057 Message::Iterator PriorityQueue::begin(void) { return Message::Iterator(GetHead()); }
1058 
begin(void) const1059 Message::ConstIterator PriorityQueue::begin(void) const { return Message::ConstIterator(GetHead()); }
1060 
GetInfo(Info & aInfo) const1061 void PriorityQueue::GetInfo(Info &aInfo) const
1062 {
1063     for (const Message &message : *this)
1064     {
1065         aInfo.mNumMessages++;
1066         aInfo.mNumBuffers += message.GetBufferCount();
1067         aInfo.mTotalBytes += message.GetLength();
1068     }
1069 }
1070 
1071 } // namespace ot
1072 #endif // OPENTHREAD_MTD || OPENTHREAD_FTD
1073