1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the message buffer pool and message buffers.
32  */
33 
34 #include "message.hpp"
35 
36 #include "common/as_core_type.hpp"
37 #include "common/code_utils.hpp"
38 #include "common/debug.hpp"
39 #include "common/heap.hpp"
40 #include "common/instance.hpp"
41 #include "common/locator_getters.hpp"
42 #include "common/log.hpp"
43 #include "common/num_utils.hpp"
44 #include "common/numeric_limits.hpp"
45 #include "net/checksum.hpp"
46 #include "net/ip6.hpp"
47 
48 #if OPENTHREAD_MTD || OPENTHREAD_FTD
49 
50 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE && OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
51 #error "OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE conflicts with OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT."
52 #endif
53 
54 namespace ot {
55 
// Registers the log module name for this file; presumably it tags the output of the `LogInfo()` calls below.
RegisterLogModule("Message");
57 
58 //---------------------------------------------------------------------------------------------------------------------
59 // MessagePool
60 
MessagePool(Instance & aInstance)61 MessagePool::MessagePool(Instance &aInstance)
62     : InstanceLocator(aInstance)
63     , mNumAllocated(0)
64     , mMaxAllocated(0)
65 {
66 #if OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
67     otPlatMessagePoolInit(&GetInstance(), kNumBuffers, sizeof(Buffer));
68 #endif
69 }
70 
Allocate(Message::Type aType,uint16_t aReserveHeader,const Message::Settings & aSettings)71 Message *MessagePool::Allocate(Message::Type aType, uint16_t aReserveHeader, const Message::Settings &aSettings)
72 {
73     Error    error = kErrorNone;
74     Message *message;
75 
76     VerifyOrExit((message = static_cast<Message *>(NewBuffer(aSettings.GetPriority()))) != nullptr);
77 
78     memset(message, 0, sizeof(*message));
79     message->SetMessagePool(this);
80     message->SetType(aType);
81     message->SetReserved(aReserveHeader);
82     message->SetLinkSecurityEnabled(aSettings.IsLinkSecurityEnabled());
83 
84     SuccessOrExit(error = message->SetPriority(aSettings.GetPriority()));
85     SuccessOrExit(error = message->SetLength(0));
86 
87 exit:
88     if (error != kErrorNone)
89     {
90         Free(message);
91         message = nullptr;
92     }
93 
94     return message;
95 }
96 
Allocate(Message::Type aType)97 Message *MessagePool::Allocate(Message::Type aType) { return Allocate(aType, 0, Message::Settings::GetDefault()); }
98 
Allocate(Message::Type aType,uint16_t aReserveHeader)99 Message *MessagePool::Allocate(Message::Type aType, uint16_t aReserveHeader)
100 {
101     return Allocate(aType, aReserveHeader, Message::Settings::GetDefault());
102 }
103 
Free(Message * aMessage)104 void MessagePool::Free(Message *aMessage)
105 {
106     OT_ASSERT(aMessage->Next() == nullptr && aMessage->Prev() == nullptr);
107 
108     FreeBuffers(static_cast<Buffer *>(aMessage));
109 }
110 
NewBuffer(Message::Priority aPriority)111 Buffer *MessagePool::NewBuffer(Message::Priority aPriority)
112 {
113     Buffer *buffer = nullptr;
114 
115     while ((
116 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
117                buffer = static_cast<Buffer *>(Heap::CAlloc(1, sizeof(Buffer)))
118 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
119                buffer = static_cast<Buffer *>(otPlatMessagePoolNew(&GetInstance()))
120 #else
121                buffer = mBufferPool.Allocate()
122 #endif
123                    ) == nullptr)
124     {
125         SuccessOrExit(ReclaimBuffers(aPriority));
126     }
127 
128     mNumAllocated++;
129     mMaxAllocated = Max(mMaxAllocated, mNumAllocated);
130 
131     buffer->SetNextBuffer(nullptr);
132 
133 exit:
134     if (buffer == nullptr)
135     {
136         LogInfo("No available message buffer");
137     }
138 
139     return buffer;
140 }
141 
FreeBuffers(Buffer * aBuffer)142 void MessagePool::FreeBuffers(Buffer *aBuffer)
143 {
144     while (aBuffer != nullptr)
145     {
146         Buffer *next = aBuffer->GetNextBuffer();
147 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
148         Heap::Free(aBuffer);
149 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
150         otPlatMessagePoolFree(&GetInstance(), aBuffer);
151 #else
152         mBufferPool.Free(*aBuffer);
153 #endif
154         mNumAllocated--;
155 
156         aBuffer = next;
157     }
158 }
159 
ReclaimBuffers(Message::Priority aPriority)160 Error MessagePool::ReclaimBuffers(Message::Priority aPriority) { return Get<MeshForwarder>().EvictMessage(aPriority); }
161 
GetFreeBufferCount(void) const162 uint16_t MessagePool::GetFreeBufferCount(void) const
163 {
164     uint16_t rval;
165 
166 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
167 #if !OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE
168     rval = static_cast<uint16_t>(Instance::GetHeap().GetFreeSize() / sizeof(Buffer));
169 #else
170     rval = NumericLimits<uint16_t>::kMax;
171 #endif
172 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
173     rval = otPlatMessagePoolNumFreeBuffers(&GetInstance());
174 #else
175     rval = kNumBuffers - mNumAllocated;
176 #endif
177 
178     return rval;
179 }
180 
GetTotalBufferCount(void) const181 uint16_t MessagePool::GetTotalBufferCount(void) const
182 {
183     uint16_t rval;
184 
185 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
186 #if !OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE
187     rval = static_cast<uint16_t>(Instance::GetHeap().GetCapacity() / sizeof(Buffer));
188 #else
189     rval = NumericLimits<uint16_t>::kMax;
190 #endif
191 #else
192     rval = OPENTHREAD_CONFIG_NUM_MESSAGE_BUFFERS;
193 #endif
194 
195     return rval;
196 }
197 
198 //---------------------------------------------------------------------------------------------------------------------
199 // Message::Settings
200 
// Default message settings, initialized from `kWithLinkSecurity` and `kPriorityNormal`.
// NOTE(review): presumably what `Settings::GetDefault()` returns — confirm in the header.
const otMessageSettings Message::Settings::kDefault = {kWithLinkSecurity, kPriorityNormal};
202 
Settings(LinkSecurityMode aSecurityMode,Priority aPriority)203 Message::Settings::Settings(LinkSecurityMode aSecurityMode, Priority aPriority)
204 {
205     mLinkSecurityEnabled = aSecurityMode;
206     mPriority            = aPriority;
207 }
208 
From(const otMessageSettings * aSettings)209 const Message::Settings &Message::Settings::From(const otMessageSettings *aSettings)
210 {
211     return (aSettings == nullptr) ? GetDefault() : AsCoreType(aSettings);
212 }
213 
214 //---------------------------------------------------------------------------------------------------------------------
215 // Message::Iterator
216 
Advance(void)217 void Message::Iterator::Advance(void)
218 {
219     mItem = mNext;
220     mNext = NextMessage(mNext);
221 }
222 
223 //---------------------------------------------------------------------------------------------------------------------
224 // Message
225 
ResizeMessage(uint16_t aLength)226 Error Message::ResizeMessage(uint16_t aLength)
227 {
228     // This method adds or frees message buffers to meet the
229     // requested length.
230 
231     Error    error     = kErrorNone;
232     Buffer  *curBuffer = this;
233     Buffer  *lastBuffer;
234     uint16_t curLength = kHeadBufferDataSize;
235 
236     while (curLength < aLength)
237     {
238         if (curBuffer->GetNextBuffer() == nullptr)
239         {
240             curBuffer->SetNextBuffer(GetMessagePool()->NewBuffer(GetPriority()));
241             VerifyOrExit(curBuffer->GetNextBuffer() != nullptr, error = kErrorNoBufs);
242         }
243 
244         curBuffer = curBuffer->GetNextBuffer();
245         curLength += kBufferDataSize;
246     }
247 
248     lastBuffer = curBuffer;
249     curBuffer  = curBuffer->GetNextBuffer();
250     lastBuffer->SetNextBuffer(nullptr);
251 
252     GetMessagePool()->FreeBuffers(curBuffer);
253 
254 exit:
255     return error;
256 }
257 
Free(void)258 void Message::Free(void) { GetMessagePool()->Free(this); }
259 
GetNext(void) const260 Message *Message::GetNext(void) const
261 {
262     Message *next;
263     Message *tail;
264 
265     if (GetMetadata().mInPriorityQ)
266     {
267         PriorityQueue *priorityQueue = GetPriorityQueue();
268         VerifyOrExit(priorityQueue != nullptr, next = nullptr);
269         tail = priorityQueue->GetTail();
270     }
271     else
272     {
273         MessageQueue *messageQueue = GetMessageQueue();
274         VerifyOrExit(messageQueue != nullptr, next = nullptr);
275         tail = messageQueue->GetTail();
276     }
277 
278     next = (this == tail) ? nullptr : Next();
279 
280 exit:
281     return next;
282 }
283 
SetLength(uint16_t aLength)284 Error Message::SetLength(uint16_t aLength)
285 {
286     Error    error              = kErrorNone;
287     uint16_t totalLengthRequest = GetReserved() + aLength;
288 
289     VerifyOrExit(totalLengthRequest >= GetReserved(), error = kErrorInvalidArgs);
290 
291     SuccessOrExit(error = ResizeMessage(totalLengthRequest));
292     GetMetadata().mLength = aLength;
293 
294     // Correct the offset in case shorter length is set.
295     if (GetOffset() > aLength)
296     {
297         SetOffset(aLength);
298     }
299 
300 exit:
301     return error;
302 }
303 
GetBufferCount(void) const304 uint8_t Message::GetBufferCount(void) const
305 {
306     uint8_t rval = 1;
307 
308     for (const Buffer *curBuffer = GetNextBuffer(); curBuffer; curBuffer = curBuffer->GetNextBuffer())
309     {
310         rval++;
311     }
312 
313     return rval;
314 }
315 
/**
 * Moves the message offset by a signed delta.
 *
 * @param[in] aDelta  Number of bytes to move the offset by (may be negative). The resulting offset must not
 *                    exceed the message length (asserted).
 */
void Message::MoveOffset(int aDelta)
{
    OT_ASSERT(GetOffset() + aDelta <= GetLength());
    GetMetadata().mOffset += static_cast<int16_t>(aDelta);
    // NOTE(review): the first assert does not reject a negative result. Presumably an underflow wraps `mOffset`
    // (if it is unsigned) to a large value, which this second assert catches — confirm `mOffset`'s type.
    OT_ASSERT(GetMetadata().mOffset <= GetLength());
}
322 
/**
 * Sets the message offset.
 *
 * @param[in] aOffset  The new offset; must not exceed the message length (asserted).
 */
void Message::SetOffset(uint16_t aOffset)
{
    OT_ASSERT(aOffset <= GetLength());
    GetMetadata().mOffset = aOffset;
}
328 
IsSubTypeMle(void) const329 bool Message::IsSubTypeMle(void) const
330 {
331     bool rval;
332 
333     switch (GetMetadata().mSubType)
334     {
335     case kSubTypeMleGeneral:
336     case kSubTypeMleAnnounce:
337     case kSubTypeMleDiscoverRequest:
338     case kSubTypeMleDiscoverResponse:
339     case kSubTypeMleChildUpdateRequest:
340     case kSubTypeMleDataResponse:
341     case kSubTypeMleChildIdRequest:
342         rval = true;
343         break;
344 
345     default:
346         rval = false;
347         break;
348     }
349 
350     return rval;
351 }
352 
SetPriority(Priority aPriority)353 Error Message::SetPriority(Priority aPriority)
354 {
355     Error          error    = kErrorNone;
356     uint8_t        priority = static_cast<uint8_t>(aPriority);
357     PriorityQueue *priorityQueue;
358 
359     static_assert(kNumPriorities <= 4, "`Metadata::mPriority` as a 2-bit field cannot fit all `Priority` values");
360 
361     VerifyOrExit(priority < kNumPriorities, error = kErrorInvalidArgs);
362 
363     VerifyOrExit(IsInAQueue(), GetMetadata().mPriority = priority);
364     VerifyOrExit(GetMetadata().mPriority != priority);
365 
366     priorityQueue = GetPriorityQueue();
367 
368     if (priorityQueue != nullptr)
369     {
370         priorityQueue->Dequeue(*this);
371     }
372 
373     GetMetadata().mPriority = priority;
374 
375     if (priorityQueue != nullptr)
376     {
377         priorityQueue->Enqueue(*this);
378     }
379 
380 exit:
381     return error;
382 }
383 
PriorityToString(Priority aPriority)384 const char *Message::PriorityToString(Priority aPriority)
385 {
386     static const char *const kPriorityStrings[] = {
387         "low",    // (0) kPriorityLow
388         "normal", // (1) kPriorityNormal
389         "high",   // (2) kPriorityHigh
390         "net",    // (3) kPriorityNet
391     };
392 
393     static_assert(kPriorityLow == 0, "kPriorityLow value is incorrect");
394     static_assert(kPriorityNormal == 1, "kPriorityNormal value is incorrect");
395     static_assert(kPriorityHigh == 2, "kPriorityHigh value is incorrect");
396     static_assert(kPriorityNet == 3, "kPriorityNet value is incorrect");
397 
398     return kPriorityStrings[aPriority];
399 }
400 
AppendBytes(const void * aBuf,uint16_t aLength)401 Error Message::AppendBytes(const void *aBuf, uint16_t aLength)
402 {
403     Error    error     = kErrorNone;
404     uint16_t oldLength = GetLength();
405 
406     SuccessOrExit(error = SetLength(GetLength() + aLength));
407     WriteBytes(oldLength, aBuf, aLength);
408 
409 exit:
410     return error;
411 }
412 
AppendBytesFromMessage(const Message & aMessage,uint16_t aOffset,uint16_t aLength)413 Error Message::AppendBytesFromMessage(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
414 {
415     Error    error       = kErrorNone;
416     uint16_t writeOffset = GetLength();
417     Chunk    chunk;
418 
419     VerifyOrExit(aMessage.GetLength() >= aOffset + aLength, error = kErrorParse);
420     SuccessOrExit(error = SetLength(GetLength() + aLength));
421 
422     aMessage.GetFirstChunk(aOffset, aLength, chunk);
423 
424     while (chunk.GetLength() > 0)
425     {
426         WriteBytes(writeOffset, chunk.GetBytes(), chunk.GetLength());
427         writeOffset += chunk.GetLength();
428         aMessage.GetNextChunk(aLength, chunk);
429     }
430 
431 exit:
432     return error;
433 }
434 
/**
 * Prepends bytes to the front of the message, using reserved header space and adding buffers if needed.
 *
 * The offset is advanced by @p aLength so it keeps pointing at the same payload byte.
 *
 * @param[in] aBuf     Bytes to write at the front, or `nullptr` to only make room (contents left as-is).
 * @param[in] aLength  Number of bytes to prepend.
 *
 * @retval kErrorNone    The bytes were prepended.
 * @retval kErrorNoBufs  A new buffer could not be allocated. Buffers added in earlier loop iterations stay
 *                       linked, with the reserved length grown to match.
 */
Error Message::PrependBytes(const void *aBuf, uint16_t aLength)
{
    Error   error     = kErrorNone;
    Buffer *newBuffer = nullptr;

    // Grow the reserved header space one buffer at a time until it can hold `aLength` bytes. Each new buffer
    // is inserted directly after the head buffer.
    while (aLength > GetReserved())
    {
        VerifyOrExit((newBuffer = GetMessagePool()->NewBuffer(GetPriority())) != nullptr, error = kErrorNoBufs);

        newBuffer->SetNextBuffer(GetNextBuffer());
        SetNextBuffer(newBuffer);

        if (GetReserved() < sizeof(mBuffer.mHead.mData))
        {
            // Copy payload from the first buffer.
            // The head's payload bytes (from the reserved boundary to the end of the head data area) are
            // copied into the new buffer, addressed through the same `mHead.mData` view. Adding
            // `kBufferDataSize` to the reserved length below shifts every data position by one full buffer.
            // NOTE(review): this relies on the `Buffer` layout in the header (the head and non-head data areas
            // presumably both end at the end of the buffer) — confirm.
            memcpy(newBuffer->mBuffer.mHead.mData + GetReserved(), mBuffer.mHead.mData + GetReserved(),
                   sizeof(mBuffer.mHead.mData) - GetReserved());
        }

        SetReserved(GetReserved() + kBufferDataSize);
    }

    // Claim `aLength` bytes of the reserved space as message data.
    SetReserved(GetReserved() - aLength);
    GetMetadata().mLength += aLength;
    SetOffset(GetOffset() + aLength);

    if (aBuf != nullptr)
    {
        WriteBytes(0, aBuf, aLength);
    }

exit:
    return error;
}
469 
RemoveHeader(uint16_t aLength)470 void Message::RemoveHeader(uint16_t aLength)
471 {
472     OT_ASSERT(aLength <= GetMetadata().mLength);
473 
474     GetMetadata().mReserved += aLength;
475     GetMetadata().mLength -= aLength;
476 
477     if (GetMetadata().mOffset > aLength)
478     {
479         GetMetadata().mOffset -= aLength;
480     }
481     else
482     {
483         GetMetadata().mOffset = 0;
484     }
485 }
486 
RemoveHeader(uint16_t aOffset,uint16_t aLength)487 void Message::RemoveHeader(uint16_t aOffset, uint16_t aLength)
488 {
489     // To shrink the header, we copy the header byte before `aOffset`
490     // forward. Starting at offset `aLength`, we write bytes we read
491     // from offset `0` onward and copy a total of `aOffset` bytes.
492     // Then remove the first `aLength` bytes from message.
493     //
494     //
495     // 0                   aOffset  aOffset + aLength
496     // +-----------------------+---------+------------------------+
497     // | / / / / / / / / / / / | x x x x |                        |
498     // +-----------------------+---------+------------------------+
499     //
500     // 0       aLength                aOffset + aLength
501     // +---------+-----------------------+------------------------+
502     // |         | / / / / / / / / / / / |                        |
503     // +---------+-----------------------+------------------------+
504     //
505     //  0                    aOffset
506     //  +-----------------------+------------------------+
507     //  | / / / / / / / / / / / |                        |
508     //  +-----------------------+------------------------+
509     //
510 
511     WriteBytesFromMessage(/* aWriteOffset */ aLength, *this, /* aReadOffset */ 0, /* aLength */ aOffset);
512     RemoveHeader(aLength);
513 }
514 
InsertHeader(uint16_t aOffset,uint16_t aLength)515 Error Message::InsertHeader(uint16_t aOffset, uint16_t aLength)
516 {
517     Error error;
518 
519     // To make space in header at `aOffset`, we first prepend
520     // `aLength` bytes at front. Then copy the existing bytes
521     // backwards. Starting at offset `0`, we write bytes we read
522     // from offset `aLength` onward and copy a total of `aOffset`
523     // bytes.
524     //
525     // 0                    aOffset
526     // +-----------------------+------------------------+
527     // | / / / / / / / / / / / |                        |
528     // +-----------------------+------------------------+
529     //
530     // 0       aLength                aOffset + aLength
531     // +---------+-----------------------+------------------------+
532     // |         | / / / / / / / / / / / |                        |
533     // +---------+-----------------------+------------------------+
534     //
535     // 0                   aOffset  aOffset + aLength
536     // +-----------------------+---------+------------------------+
537     // | / / / / / / / / / / / |  N E W  |                        |
538     // +-----------------------+---------+------------------------+
539     //
540 
541     SuccessOrExit(error = PrependBytes(nullptr, aLength));
542     WriteBytesFromMessage(/* aWriteOffset */ 0, *this, /* aReadOffset */ aLength, /* aLength */ aOffset);
543 
544 exit:
545     return error;
546 }
547 
/**
 * Gets the first contiguous chunk of message data for a given offset and length.
 *
 * @param[in]     aOffset  Offset into the message data (reserved header space is skipped internally).
 * @param[in,out] aLength  On input, the requested number of bytes, first clamped so the range does not pass the
 *                         end of the message. On output, decreased by the length of the returned chunk.
 * @param[out]    aChunk   Set to the chunk's start pointer, length and owning buffer. If @p aOffset is at or
 *                         beyond the message end, only its length is set (to zero) and its buffer is not updated.
 */
void Message::GetFirstChunk(uint16_t aOffset, uint16_t &aLength, Chunk &aChunk) const
{
    // This method gets the first message chunk (contiguous data
    // buffer) corresponding to a given offset and length. On exit
    // `aChunk` is updated such that `aChunk.GetBytes()` gives the
    // pointer to the start of chunk and `aChunk.GetLength()` gives
    // its length. The `aLength` is also decreased by the chunk
    // length.

    VerifyOrExit(aOffset < GetLength(), aChunk.SetLength(0));

    // Clamp the requested range so it ends at or before the message end.
    if (aOffset + aLength >= GetLength())
    {
        aLength = GetLength() - aOffset;
    }

    // Convert to a position within the buffer chain, which starts with the reserved header space.
    aOffset += GetReserved();

    aChunk.SetBuffer(this);

    // Special case for the first buffer

    if (aOffset < kHeadBufferDataSize)
    {
        aChunk.Init(GetFirstData() + aOffset, kHeadBufferDataSize - aOffset);
        ExitNow();
    }

    aOffset -= kHeadBufferDataSize;

    // Find the `Buffer` matching the offset

    while (true)
    {
        aChunk.SetBuffer(aChunk.GetBuffer()->GetNextBuffer());

        OT_ASSERT(aChunk.GetBuffer() != nullptr);

        if (aOffset < kBufferDataSize)
        {
            aChunk.Init(aChunk.GetBuffer()->GetData() + aOffset, kBufferDataSize - aOffset);
            ExitNow();
        }

        aOffset -= kBufferDataSize;
    }

exit:
    // A chunk runs to the end of its buffer; trim it to the (clamped) requested length.
    if (aChunk.GetLength() > aLength)
    {
        aChunk.SetLength(aLength);
    }

    aLength -= aChunk.GetLength();
}
603 
GetNextChunk(uint16_t & aLength,Chunk & aChunk) const604 void Message::GetNextChunk(uint16_t &aLength, Chunk &aChunk) const
605 {
606     // This method gets the next message chunk. On input, the
607     // `aChunk` should be the previous chunk. On exit, it is
608     // updated to provide info about next chunk, and `aLength`
609     // is decreased by the chunk length. If there is no more
610     // chunk, `aChunk.GetLength()` would be zero.
611 
612     VerifyOrExit(aLength > 0, aChunk.SetLength(0));
613 
614     aChunk.SetBuffer(aChunk.GetBuffer()->GetNextBuffer());
615 
616     OT_ASSERT(aChunk.GetBuffer() != nullptr);
617 
618     aChunk.Init(aChunk.GetBuffer()->GetData(), kBufferDataSize);
619 
620     if (aChunk.GetLength() > aLength)
621     {
622         aChunk.SetLength(aLength);
623     }
624 
625     aLength -= aChunk.GetLength();
626 
627 exit:
628     return;
629 }
630 
ReadBytes(uint16_t aOffset,void * aBuf,uint16_t aLength) const631 uint16_t Message::ReadBytes(uint16_t aOffset, void *aBuf, uint16_t aLength) const
632 {
633     uint8_t *bufPtr = reinterpret_cast<uint8_t *>(aBuf);
634     Chunk    chunk;
635 
636     GetFirstChunk(aOffset, aLength, chunk);
637 
638     while (chunk.GetLength() > 0)
639     {
640         chunk.CopyBytesTo(bufPtr);
641         bufPtr += chunk.GetLength();
642         GetNextChunk(aLength, chunk);
643     }
644 
645     return static_cast<uint16_t>(bufPtr - reinterpret_cast<uint8_t *>(aBuf));
646 }
647 
Read(uint16_t aOffset,void * aBuf,uint16_t aLength) const648 Error Message::Read(uint16_t aOffset, void *aBuf, uint16_t aLength) const
649 {
650     return (ReadBytes(aOffset, aBuf, aLength) == aLength) ? kErrorNone : kErrorParse;
651 }
652 
CompareBytes(uint16_t aOffset,const void * aBuf,uint16_t aLength,ByteMatcher aMatcher) const653 bool Message::CompareBytes(uint16_t aOffset, const void *aBuf, uint16_t aLength, ByteMatcher aMatcher) const
654 {
655     uint16_t       bytesToCompare = aLength;
656     const uint8_t *bufPtr         = reinterpret_cast<const uint8_t *>(aBuf);
657     Chunk          chunk;
658 
659     GetFirstChunk(aOffset, aLength, chunk);
660 
661     while (chunk.GetLength() > 0)
662     {
663         VerifyOrExit(chunk.MatchesBytesIn(bufPtr, aMatcher));
664         bufPtr += chunk.GetLength();
665         bytesToCompare -= chunk.GetLength();
666         GetNextChunk(aLength, chunk);
667     }
668 
669 exit:
670     return (bytesToCompare == 0);
671 }
672 
CompareBytes(uint16_t aOffset,const Message & aOtherMessage,uint16_t aOtherOffset,uint16_t aLength,ByteMatcher aMatcher) const673 bool Message::CompareBytes(uint16_t       aOffset,
674                            const Message &aOtherMessage,
675                            uint16_t       aOtherOffset,
676                            uint16_t       aLength,
677                            ByteMatcher    aMatcher) const
678 {
679     uint16_t bytesToCompare = aLength;
680     Chunk    chunk;
681 
682     GetFirstChunk(aOffset, aLength, chunk);
683 
684     while (chunk.GetLength() > 0)
685     {
686         VerifyOrExit(aOtherMessage.CompareBytes(aOtherOffset, chunk.GetBytes(), chunk.GetLength(), aMatcher));
687         aOtherOffset += chunk.GetLength();
688         bytesToCompare -= chunk.GetLength();
689         GetNextChunk(aLength, chunk);
690     }
691 
692 exit:
693     return (bytesToCompare == 0);
694 }
695 
WriteBytes(uint16_t aOffset,const void * aBuf,uint16_t aLength)696 void Message::WriteBytes(uint16_t aOffset, const void *aBuf, uint16_t aLength)
697 {
698     const uint8_t *bufPtr = reinterpret_cast<const uint8_t *>(aBuf);
699     MutableChunk   chunk;
700 
701     OT_ASSERT(aOffset + aLength <= GetLength());
702 
703     GetFirstChunk(aOffset, aLength, chunk);
704 
705     while (chunk.GetLength() > 0)
706     {
707         memmove(chunk.GetBytes(), bufPtr, chunk.GetLength());
708         bufPtr += chunk.GetLength();
709         GetNextChunk(aLength, chunk);
710     }
711 }
712 
WriteBytesFromMessage(uint16_t aWriteOffset,const Message & aMessage,uint16_t aReadOffset,uint16_t aLength)713 void Message::WriteBytesFromMessage(uint16_t       aWriteOffset,
714                                     const Message &aMessage,
715                                     uint16_t       aReadOffset,
716                                     uint16_t       aLength)
717 {
718     if ((&aMessage != this) || (aReadOffset >= aWriteOffset))
719     {
720         Chunk chunk;
721 
722         aMessage.GetFirstChunk(aReadOffset, aLength, chunk);
723 
724         while (chunk.GetLength() > 0)
725         {
726             WriteBytes(aWriteOffset, chunk.GetBytes(), chunk.GetLength());
727             aWriteOffset += chunk.GetLength();
728             aMessage.GetNextChunk(aLength, chunk);
729         }
730     }
731     else
732     {
733         // We are copying bytes within the same message forward.
734         // To ensure copy forward works, we read and write from
735         // end of range and move backwards.
736 
737         static constexpr uint16_t kBufSize = 32;
738 
739         uint8_t buf[kBufSize];
740 
741         aWriteOffset += aLength;
742         aReadOffset += aLength;
743 
744         while (aLength > 0)
745         {
746             uint16_t copyLength = Min(kBufSize, aLength);
747 
748             aLength -= copyLength;
749             aReadOffset -= copyLength;
750             aWriteOffset -= copyLength;
751 
752             ReadBytes(aReadOffset, buf, copyLength);
753             WriteBytes(aWriteOffset, buf, copyLength);
754         }
755     }
756 }
757 
Clone(uint16_t aLength) const758 Message *Message::Clone(uint16_t aLength) const
759 {
760     Error    error = kErrorNone;
761     Message *messageCopy;
762     Settings settings(IsLinkSecurityEnabled() ? kWithLinkSecurity : kNoLinkSecurity, GetPriority());
763     uint16_t offset;
764 
765     aLength     = Min(GetLength(), aLength);
766     messageCopy = GetMessagePool()->Allocate(GetType(), GetReserved(), settings);
767     VerifyOrExit(messageCopy != nullptr, error = kErrorNoBufs);
768     SuccessOrExit(error = messageCopy->AppendBytesFromMessage(*this, 0, aLength));
769 
770     // Copy selected message information.
771     offset = Min(GetOffset(), aLength);
772     messageCopy->SetOffset(offset);
773 
774     messageCopy->SetSubType(GetSubType());
775 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
776     messageCopy->SetTimeSync(IsTimeSync());
777 #endif
778 
779 exit:
780     FreeAndNullMessageOnError(messageCopy, error);
781     return messageCopy;
782 }
783 
#if OPENTHREAD_FTD
// Returns the child-mask entry for child index `aChildIndex`.
bool Message::GetChildMask(uint16_t aChildIndex) const
{
    return GetMetadata().mChildMask.Get(aChildIndex);
}

// Clears the child-mask entry for child index `aChildIndex`.
void Message::ClearChildMask(uint16_t aChildIndex)
{
    GetMetadata().mChildMask.Set(aChildIndex, false);
}

// Sets the child-mask entry for child index `aChildIndex`.
void Message::SetChildMask(uint16_t aChildIndex)
{
    GetMetadata().mChildMask.Set(aChildIndex, true);
}

// Returns whether any child-mask entry is set.
bool Message::IsChildPending(void) const
{
    return GetMetadata().mChildMask.HasAny();
}
#endif
793 
SetLinkInfo(const ThreadLinkInfo & aLinkInfo)794 void Message::SetLinkInfo(const ThreadLinkInfo &aLinkInfo)
795 {
796     SetLinkSecurityEnabled(aLinkInfo.mLinkSecurity);
797     SetPanId(aLinkInfo.mPanId);
798     AddRss(aLinkInfo.mRss);
799 #if OPENTHREAD_CONFIG_MLE_LINK_METRICS_SUBJECT_ENABLE
800     AddLqi(aLinkInfo.mLqi);
801 #endif
802 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
803     SetTimeSyncSeq(aLinkInfo.mTimeSyncSeq);
804     SetNetworkTimeOffset(aLinkInfo.mNetworkTimeOffset);
805 #endif
806 #if OPENTHREAD_CONFIG_MULTI_RADIO
807     SetRadioType(static_cast<Mac::RadioType>(aLinkInfo.mRadioType));
808 #endif
809 }
810 
IsTimeSync(void) const811 bool Message::IsTimeSync(void) const
812 {
813 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
814     return GetMetadata().mTimeSync;
815 #else
816     return false;
817 #endif
818 }
819 
SetMessageQueue(MessageQueue * aMessageQueue)820 void Message::SetMessageQueue(MessageQueue *aMessageQueue)
821 {
822     GetMetadata().mQueue       = aMessageQueue;
823     GetMetadata().mInPriorityQ = false;
824 }
825 
SetPriorityQueue(PriorityQueue * aPriorityQueue)826 void Message::SetPriorityQueue(PriorityQueue *aPriorityQueue)
827 {
828     GetMetadata().mQueue       = aPriorityQueue;
829     GetMetadata().mInPriorityQ = true;
830 }
831 
832 //---------------------------------------------------------------------------------------------------------------------
833 // MessageQueue
834 
/**
 * Adds a message to this queue.
 *
 * The queue is a circular doubly-linked list referenced by its tail; the head is `GetTail()->Next()`. A new
 * message is always linked in between the current tail and head. With `kQueuePositionTail` it also becomes the
 * new tail; otherwise it stays right after the old tail, i.e. it becomes the new head.
 *
 * @param[in] aMessage   The message to add. It must not already be in a queue (asserted).
 * @param[in] aPosition  Whether the message goes to the tail or (any other value) the head.
 */
void MessageQueue::Enqueue(Message &aMessage, QueuePosition aPosition)
{
    OT_ASSERT(!aMessage.IsInAQueue());
    OT_ASSERT((aMessage.Next() == nullptr) && (aMessage.Prev() == nullptr));

    aMessage.SetMessageQueue(this);

    if (GetTail() == nullptr)
    {
        // Empty queue: the message forms a one-element ring.
        aMessage.Next() = &aMessage;
        aMessage.Prev() = &aMessage;

        SetTail(&aMessage);
    }
    else
    {
        Message *head = GetTail()->Next();

        // Splice the message in between the tail and the head.
        aMessage.Next() = head;
        aMessage.Prev() = GetTail();

        head->Prev()      = &aMessage;
        GetTail()->Next() = &aMessage;

        if (aPosition == kQueuePositionTail)
        {
            SetTail(&aMessage);
        }
    }
}
865 
/**
 * Removes a message from this queue and clears its queue links.
 *
 * @param[in] aMessage  The message to remove. It must currently be in this queue (asserted).
 */
void MessageQueue::Dequeue(Message &aMessage)
{
    OT_ASSERT(aMessage.GetMessageQueue() == this);
    OT_ASSERT((aMessage.Next() != nullptr) && (aMessage.Prev() != nullptr));

    if (&aMessage == GetTail())
    {
        // The previous element becomes the tail. If that is the message itself, it was the only element and
        // the queue becomes empty.
        SetTail(GetTail()->Prev());

        if (&aMessage == GetTail())
        {
            SetTail(nullptr);
        }
    }

    aMessage.Prev()->Next() = aMessage.Next();
    aMessage.Next()->Prev() = aMessage.Prev();

    // Null links mark the message as unlinked (`MessagePool::Free()` asserts this).
    aMessage.Prev() = nullptr;
    aMessage.Next() = nullptr;

    aMessage.SetMessageQueue(nullptr);
}
889 
/**
 * Removes a message from this queue and then frees it.
 *
 * Dequeueing must come first: freeing goes through `MessagePool::Free()`, which asserts the message's queue
 * links are null.
 *
 * @param[in] aMessage  A message currently in this queue.
 */
void MessageQueue::DequeueAndFree(Message &aMessage)
{
    Dequeue(aMessage);
    aMessage.Free();
}
895 
DequeueAndFreeAll(void)896 void MessageQueue::DequeueAndFreeAll(void)
897 {
898     Message *message;
899 
900     while ((message = GetHead()) != nullptr)
901     {
902         DequeueAndFree(*message);
903     }
904 }
905 
begin(void)906 Message::Iterator MessageQueue::begin(void) { return Message::Iterator(GetHead()); }
907 
begin(void) const908 Message::ConstIterator MessageQueue::begin(void) const { return Message::ConstIterator(GetHead()); }
909 
GetInfo(Info & aInfo) const910 void MessageQueue::GetInfo(Info &aInfo) const
911 {
912     for (const Message &message : *this)
913     {
914         aInfo.mNumMessages++;
915         aInfo.mNumBuffers += message.GetBufferCount();
916         aInfo.mTotalBytes += message.GetLength();
917     }
918 }
919 
920 //---------------------------------------------------------------------------------------------------------------------
921 // PriorityQueue
922 
FindFirstNonNullTail(Message::Priority aStartPriorityLevel) const923 const Message *PriorityQueue::FindFirstNonNullTail(Message::Priority aStartPriorityLevel) const
924 {
925     // Find the first non-`nullptr` tail starting from the given priority
926     // level and moving forward (wrapping from priority value
927     // `kNumPriorities` -1 back to 0).
928 
929     const Message *tail = nullptr;
930     uint8_t        priority;
931 
932     priority = static_cast<uint8_t>(aStartPriorityLevel);
933 
934     do
935     {
936         if (mTails[priority] != nullptr)
937         {
938             tail = mTails[priority];
939             break;
940         }
941 
942         priority = PrevPriority(priority);
943     } while (priority != aStartPriorityLevel);
944 
945     return tail;
946 }
947 
GetHead(void) const948 const Message *PriorityQueue::GetHead(void) const
949 {
950     return Message::NextOf(FindFirstNonNullTail(Message::kPriorityLow));
951 }
952 
GetHeadForPriority(Message::Priority aPriority) const953 const Message *PriorityQueue::GetHeadForPriority(Message::Priority aPriority) const
954 {
955     const Message *head;
956     const Message *previousTail;
957 
958     if (mTails[aPriority] != nullptr)
959     {
960         previousTail = FindFirstNonNullTail(static_cast<Message::Priority>(PrevPriority(aPriority)));
961 
962         OT_ASSERT(previousTail != nullptr);
963 
964         head = previousTail->Next();
965     }
966     else
967     {
968         head = nullptr;
969     }
970 
971     return head;
972 }
973 
GetTail(void) const974 const Message *PriorityQueue::GetTail(void) const { return FindFirstNonNullTail(Message::kPriorityLow); }
975 
Enqueue(Message & aMessage)976 void PriorityQueue::Enqueue(Message &aMessage)
977 {
978     Message::Priority priority;
979     Message          *tail;
980     Message          *next;
981 
982     OT_ASSERT(!aMessage.IsInAQueue());
983 
984     aMessage.SetPriorityQueue(this);
985 
986     priority = aMessage.GetPriority();
987 
988     tail = FindFirstNonNullTail(priority);
989 
990     if (tail != nullptr)
991     {
992         next = tail->Next();
993 
994         aMessage.Next() = next;
995         aMessage.Prev() = tail;
996         next->Prev()    = &aMessage;
997         tail->Next()    = &aMessage;
998     }
999     else
1000     {
1001         aMessage.Next() = &aMessage;
1002         aMessage.Prev() = &aMessage;
1003     }
1004 
1005     mTails[priority] = &aMessage;
1006 }
1007 
Dequeue(Message & aMessage)1008 void PriorityQueue::Dequeue(Message &aMessage)
1009 {
1010     Message::Priority priority;
1011     Message          *tail;
1012 
1013     OT_ASSERT(aMessage.GetPriorityQueue() == this);
1014 
1015     priority = aMessage.GetPriority();
1016 
1017     tail = mTails[priority];
1018 
1019     if (&aMessage == tail)
1020     {
1021         tail = tail->Prev();
1022 
1023         if ((&aMessage == tail) || (tail->GetPriority() != priority))
1024         {
1025             tail = nullptr;
1026         }
1027 
1028         mTails[priority] = tail;
1029     }
1030 
1031     aMessage.Next()->Prev() = aMessage.Prev();
1032     aMessage.Prev()->Next() = aMessage.Next();
1033     aMessage.Next()         = nullptr;
1034     aMessage.Prev()         = nullptr;
1035 
1036     aMessage.SetPriorityQueue(nullptr);
1037 }
1038 
DequeueAndFree(Message & aMessage)1039 void PriorityQueue::DequeueAndFree(Message &aMessage)
1040 {
1041     Dequeue(aMessage);
1042     aMessage.Free();
1043 }
1044 
DequeueAndFreeAll(void)1045 void PriorityQueue::DequeueAndFreeAll(void)
1046 {
1047     Message *message;
1048 
1049     while ((message = GetHead()) != nullptr)
1050     {
1051         DequeueAndFree(*message);
1052     }
1053 }
1054 
begin(void)1055 Message::Iterator PriorityQueue::begin(void) { return Message::Iterator(GetHead()); }
1056 
begin(void) const1057 Message::ConstIterator PriorityQueue::begin(void) const { return Message::ConstIterator(GetHead()); }
1058 
GetInfo(Info & aInfo) const1059 void PriorityQueue::GetInfo(Info &aInfo) const
1060 {
1061     for (const Message &message : *this)
1062     {
1063         aInfo.mNumMessages++;
1064         aInfo.mNumBuffers += message.GetBufferCount();
1065         aInfo.mTotalBytes += message.GetLength();
1066     }
1067 }
1068 
1069 } // namespace ot
1070 #endif // OPENTHREAD_MTD || OPENTHREAD_FTD
1071