1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the message buffer pool and message buffers.
32  */
33 
34 #include "message.hpp"
35 
36 #include "common/as_core_type.hpp"
37 #include "common/code_utils.hpp"
38 #include "common/debug.hpp"
39 #include "common/heap.hpp"
40 #include "common/locator_getters.hpp"
41 #include "common/log.hpp"
42 #include "common/num_utils.hpp"
43 #include "common/numeric_limits.hpp"
44 #include "instance/instance.hpp"
45 #include "net/checksum.hpp"
46 #include "net/ip6.hpp"
47 
48 #if OPENTHREAD_MTD || OPENTHREAD_FTD
49 
50 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE && OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
51 #error "OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE conflicts with OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT."
52 #endif
53 
54 namespace ot {
55 
56 RegisterLogModule("Message");
57 
58 //---------------------------------------------------------------------------------------------------------------------
59 // MessagePool
60 
/**
 * Constructs the message pool.
 *
 * Both allocation counters start at zero. When `OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT` is enabled, the
 * platform pool is initialized with `kNumBuffers` buffers of `sizeof(Buffer)` bytes each.
 *
 * @param[in] aInstance  The OpenThread instance.
 */
MessagePool::MessagePool(Instance &aInstance)
    : InstanceLocator(aInstance)
    , mNumAllocated(0)
    , mMaxAllocated(0)
{
#if OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
    // In this configuration buffers are later obtained via `otPlatMessagePoolNew()`, so the platform
    // pool must be told the buffer count and size up front.
    otPlatMessagePoolInit(&GetInstance(), kNumBuffers, sizeof(Buffer));
#endif
}
70 
Allocate(Message::Type aType,uint16_t aReserveHeader,const Message::Settings & aSettings)71 Message *MessagePool::Allocate(Message::Type aType, uint16_t aReserveHeader, const Message::Settings &aSettings)
72 {
73     Error    error = kErrorNone;
74     Message *message;
75 
76     VerifyOrExit((message = static_cast<Message *>(NewBuffer(aSettings.GetPriority()))) != nullptr);
77 
78     ClearAllBytes(*message);
79     message->SetMessagePool(this);
80     message->SetType(aType);
81     message->SetReserved(aReserveHeader);
82     message->SetLinkSecurityEnabled(aSettings.IsLinkSecurityEnabled());
83     message->SetLoopbackToHostAllowed(OPENTHREAD_CONFIG_IP6_ALLOW_LOOP_BACK_HOST_DATAGRAMS);
84     message->SetOrigin(Message::kOriginHostTrusted);
85 
86     SuccessOrExit(error = message->SetPriority(aSettings.GetPriority()));
87     SuccessOrExit(error = message->SetLength(0));
88 
89 exit:
90     if (error != kErrorNone)
91     {
92         Free(message);
93         message = nullptr;
94     }
95 
96     return message;
97 }
98 
Allocate(Message::Type aType)99 Message *MessagePool::Allocate(Message::Type aType) { return Allocate(aType, 0, Message::Settings::GetDefault()); }
100 
Allocate(Message::Type aType,uint16_t aReserveHeader)101 Message *MessagePool::Allocate(Message::Type aType, uint16_t aReserveHeader)
102 {
103     return Allocate(aType, aReserveHeader, Message::Settings::GetDefault());
104 }
105 
Free(Message * aMessage)106 void MessagePool::Free(Message *aMessage)
107 {
108     OT_ASSERT(aMessage->Next() == nullptr && aMessage->Prev() == nullptr);
109 
110     FreeBuffers(static_cast<Buffer *>(aMessage));
111 }
112 
/**
 * Allocates a single buffer from the configured backing store: the heap, the platform pool, or the
 * internal `mBufferPool`.
 *
 * If no buffer is available, `ReclaimBuffers(aPriority)` is called and the allocation is retried. This
 * repeats until an allocation succeeds or reclaiming returns an error. On success, the allocation counters
 * (`mNumAllocated`, `mMaxAllocated`) are updated and the new buffer's next-buffer link is cleared.
 *
 * @param[in] aPriority  Priority passed to `ReclaimBuffers()` when the pool is exhausted.
 *
 * @returns A pointer to the new buffer, or `nullptr` if none could be obtained.
 */
Buffer *MessagePool::NewBuffer(Message::Priority aPriority)
{
    Buffer *buffer = nullptr;

    // Exactly one of the allocation expressions below is compiled in; each assigns `buffer`.
    while ((
#if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
               buffer = static_cast<Buffer *>(Heap::CAlloc(1, sizeof(Buffer)))
#elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
               buffer = static_cast<Buffer *>(otPlatMessagePoolNew(&GetInstance()))
#else
               buffer = mBufferPool.Allocate()
#endif
                   ) == nullptr)
    {
        // Try to free up space; give up (with `buffer == nullptr`) if nothing can be reclaimed.
        SuccessOrExit(ReclaimBuffers(aPriority));
    }

    mNumAllocated++;
    mMaxAllocated = Max(mMaxAllocated, mNumAllocated);

    buffer->SetNextBuffer(nullptr);

exit:
    if (buffer == nullptr)
    {
        LogInfo("No available message buffer");
    }

    return buffer;
}
143 
FreeBuffers(Buffer * aBuffer)144 void MessagePool::FreeBuffers(Buffer *aBuffer)
145 {
146     while (aBuffer != nullptr)
147     {
148         Buffer *next = aBuffer->GetNextBuffer();
149 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
150         Heap::Free(aBuffer);
151 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
152         otPlatMessagePoolFree(&GetInstance(), aBuffer);
153 #else
154         mBufferPool.Free(*aBuffer);
155 #endif
156         mNumAllocated--;
157 
158         aBuffer = next;
159     }
160 }
161 
ReclaimBuffers(Message::Priority aPriority)162 Error MessagePool::ReclaimBuffers(Message::Priority aPriority) { return Get<MeshForwarder>().EvictMessage(aPriority); }
163 
GetFreeBufferCount(void) const164 uint16_t MessagePool::GetFreeBufferCount(void) const
165 {
166     uint16_t rval;
167 
168 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
169 #if !OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE
170     rval = static_cast<uint16_t>(Instance::GetHeap().GetFreeSize() / sizeof(Buffer));
171 #else
172     rval = NumericLimits<uint16_t>::kMax;
173 #endif
174 #elif OPENTHREAD_CONFIG_PLATFORM_MESSAGE_MANAGEMENT
175     rval = otPlatMessagePoolNumFreeBuffers(&GetInstance());
176 #else
177     rval = kNumBuffers - mNumAllocated;
178 #endif
179 
180     return rval;
181 }
182 
GetTotalBufferCount(void) const183 uint16_t MessagePool::GetTotalBufferCount(void) const
184 {
185     uint16_t rval;
186 
187 #if OPENTHREAD_CONFIG_MESSAGE_USE_HEAP_ENABLE
188 #if !OPENTHREAD_CONFIG_HEAP_EXTERNAL_ENABLE
189     rval = static_cast<uint16_t>(Instance::GetHeap().GetCapacity() / sizeof(Buffer));
190 #else
191     rval = NumericLimits<uint16_t>::kMax;
192 #endif
193 #else
194     rval = OPENTHREAD_CONFIG_NUM_MESSAGE_BUFFERS;
195 #endif
196 
197     return rval;
198 }
199 
200 //---------------------------------------------------------------------------------------------------------------------
201 // Message::Settings
202 
// Default message settings: link security enabled, normal priority.
const otMessageSettings Message::Settings::kDefault = {kWithLinkSecurity, kPriorityNormal};
204 
Settings(LinkSecurityMode aSecurityMode,Priority aPriority)205 Message::Settings::Settings(LinkSecurityMode aSecurityMode, Priority aPriority)
206 {
207     mLinkSecurityEnabled = aSecurityMode;
208     mPriority            = aPriority;
209 }
210 
From(const otMessageSettings * aSettings)211 const Message::Settings &Message::Settings::From(const otMessageSettings *aSettings)
212 {
213     return (aSettings == nullptr) ? GetDefault() : AsCoreType(aSettings);
214 }
215 
216 //---------------------------------------------------------------------------------------------------------------------
217 // Message::Iterator
218 
Advance(void)219 void Message::Iterator::Advance(void)
220 {
221     mItem = mNext;
222     mNext = NextMessage(mNext);
223 }
224 
225 //---------------------------------------------------------------------------------------------------------------------
226 // Message
227 
ResizeMessage(uint16_t aLength)228 Error Message::ResizeMessage(uint16_t aLength)
229 {
230     // This method adds or frees message buffers to meet the
231     // requested length.
232 
233     Error    error     = kErrorNone;
234     Buffer  *curBuffer = this;
235     Buffer  *lastBuffer;
236     uint16_t curLength = kHeadBufferDataSize;
237 
238     while (curLength < aLength)
239     {
240         if (curBuffer->GetNextBuffer() == nullptr)
241         {
242             curBuffer->SetNextBuffer(GetMessagePool()->NewBuffer(GetPriority()));
243             VerifyOrExit(curBuffer->GetNextBuffer() != nullptr, error = kErrorNoBufs);
244         }
245 
246         curBuffer = curBuffer->GetNextBuffer();
247         curLength += kBufferDataSize;
248     }
249 
250     lastBuffer = curBuffer;
251     curBuffer  = curBuffer->GetNextBuffer();
252     lastBuffer->SetNextBuffer(nullptr);
253 
254     GetMessagePool()->FreeBuffers(curBuffer);
255 
256 exit:
257     return error;
258 }
259 
Free(void)260 void Message::Free(void) { GetMessagePool()->Free(this); }
261 
GetNext(void) const262 Message *Message::GetNext(void) const
263 {
264     Message *next;
265     Message *tail;
266 
267     if (GetMetadata().mInPriorityQ)
268     {
269         PriorityQueue *priorityQueue = GetPriorityQueue();
270         VerifyOrExit(priorityQueue != nullptr, next = nullptr);
271         tail = priorityQueue->GetTail();
272     }
273     else
274     {
275         MessageQueue *messageQueue = GetMessageQueue();
276         VerifyOrExit(messageQueue != nullptr, next = nullptr);
277         tail = messageQueue->GetTail();
278     }
279 
280     next = (this == tail) ? nullptr : Next();
281 
282 exit:
283     return next;
284 }
285 
SetLength(uint16_t aLength)286 Error Message::SetLength(uint16_t aLength)
287 {
288     Error    error              = kErrorNone;
289     uint16_t totalLengthRequest = GetReserved() + aLength;
290 
291     VerifyOrExit(totalLengthRequest >= GetReserved(), error = kErrorInvalidArgs);
292 
293     SuccessOrExit(error = ResizeMessage(totalLengthRequest));
294     GetMetadata().mLength = aLength;
295 
296     // Correct the offset in case shorter length is set.
297     if (GetOffset() > aLength)
298     {
299         SetOffset(aLength);
300     }
301 
302 exit:
303     return error;
304 }
305 
GetBufferCount(void) const306 uint8_t Message::GetBufferCount(void) const
307 {
308     uint8_t rval = 1;
309 
310     for (const Buffer *curBuffer = GetNextBuffer(); curBuffer; curBuffer = curBuffer->GetNextBuffer())
311     {
312         rval++;
313     }
314 
315     return rval;
316 }
317 
/**
 * Moves the message offset by a signed number of bytes.
 *
 * @param[in] aDelta  Number of bytes to move the offset; a negative value moves it backwards.
 */
void Message::MoveOffset(int aDelta)
{
    // The destination offset (computed in `int` arithmetic) must not exceed the message length.
    OT_ASSERT(GetOffset() + aDelta <= GetLength());
    // Cast to `int16_t` so a negative delta decreases the stored offset.
    GetMetadata().mOffset += static_cast<int16_t>(aDelta);
    // Re-check the stored value. NOTE(review): presumably `mOffset` is unsigned, so a delta that moves
    // the offset below zero wraps to a large value and trips this assert -- confirm.
    OT_ASSERT(GetMetadata().mOffset <= GetLength());
}
324 
SetOffset(uint16_t aOffset)325 void Message::SetOffset(uint16_t aOffset)
326 {
327     OT_ASSERT(aOffset <= GetLength());
328     GetMetadata().mOffset = aOffset;
329 }
330 
IsSubTypeMle(void) const331 bool Message::IsSubTypeMle(void) const
332 {
333     bool rval;
334 
335     switch (GetMetadata().mSubType)
336     {
337     case kSubTypeMleGeneral:
338     case kSubTypeMleAnnounce:
339     case kSubTypeMleDiscoverRequest:
340     case kSubTypeMleDiscoverResponse:
341     case kSubTypeMleChildUpdateRequest:
342     case kSubTypeMleDataResponse:
343     case kSubTypeMleChildIdRequest:
344         rval = true;
345         break;
346 
347     default:
348         rval = false;
349         break;
350     }
351 
352     return rval;
353 }
354 
/**
 * Sets the message priority.
 *
 * If the message is not in any queue, only the stored priority changes. If it is in a queue and the
 * priority actually changes, the message is removed from its priority queue (when `GetPriorityQueue()` is
 * non-null) and re-enqueued after the update, so its position reflects the new priority.
 *
 * @param[in] aPriority  The new priority.
 *
 * @retval kErrorNone         The priority was set (or was already equal to `aPriority`).
 * @retval kErrorInvalidArgs  `aPriority` is not a valid priority value.
 */
Error Message::SetPriority(Priority aPriority)
{
    Error          error    = kErrorNone;
    uint8_t        priority = static_cast<uint8_t>(aPriority);
    PriorityQueue *priorityQueue;

    static_assert(kNumPriorities <= 4, "`Metadata::mPriority` as a 2-bit field cannot fit all `Priority` values");

    VerifyOrExit(priority < kNumPriorities, error = kErrorInvalidArgs);

    // Not in any queue: just store the value and return.
    VerifyOrExit(IsInAQueue(), GetMetadata().mPriority = priority);
    // Unchanged priority: avoid a needless dequeue/enqueue.
    VerifyOrExit(GetMetadata().mPriority != priority);

    // NOTE(review): presumably `nullptr` when the message is in a plain `MessageQueue` rather than a
    // `PriorityQueue` -- confirm. In that case only the stored priority is updated below.
    priorityQueue = GetPriorityQueue();

    if (priorityQueue != nullptr)
    {
        priorityQueue->Dequeue(*this);
    }

    GetMetadata().mPriority = priority;

    if (priorityQueue != nullptr)
    {
        priorityQueue->Enqueue(*this);
    }

exit:
    return error;
}
385 
PriorityToString(Priority aPriority)386 const char *Message::PriorityToString(Priority aPriority)
387 {
388     static const char *const kPriorityStrings[] = {
389         "low",    // (0) kPriorityLow
390         "normal", // (1) kPriorityNormal
391         "high",   // (2) kPriorityHigh
392         "net",    // (3) kPriorityNet
393     };
394 
395     static_assert(kPriorityLow == 0, "kPriorityLow value is incorrect");
396     static_assert(kPriorityNormal == 1, "kPriorityNormal value is incorrect");
397     static_assert(kPriorityHigh == 2, "kPriorityHigh value is incorrect");
398     static_assert(kPriorityNet == 3, "kPriorityNet value is incorrect");
399 
400     return kPriorityStrings[aPriority];
401 }
402 
AppendBytes(const void * aBuf,uint16_t aLength)403 Error Message::AppendBytes(const void *aBuf, uint16_t aLength)
404 {
405     Error    error     = kErrorNone;
406     uint16_t oldLength = GetLength();
407 
408     SuccessOrExit(error = SetLength(GetLength() + aLength));
409     WriteBytes(oldLength, aBuf, aLength);
410 
411 exit:
412     return error;
413 }
414 
AppendBytesFromMessage(const Message & aMessage,const OffsetRange & aOffsetRange)415 Error Message::AppendBytesFromMessage(const Message &aMessage, const OffsetRange &aOffsetRange)
416 {
417     return AppendBytesFromMessage(aMessage, aOffsetRange.GetOffset(), aOffsetRange.GetLength());
418 }
419 
AppendBytesFromMessage(const Message & aMessage,uint16_t aOffset,uint16_t aLength)420 Error Message::AppendBytesFromMessage(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
421 {
422     Error    error       = kErrorNone;
423     uint16_t writeOffset = GetLength();
424     Chunk    chunk;
425 
426     VerifyOrExit(aMessage.GetLength() >= aOffset + aLength, error = kErrorParse);
427     SuccessOrExit(error = SetLength(GetLength() + aLength));
428 
429     aMessage.GetFirstChunk(aOffset, aLength, chunk);
430 
431     while (chunk.GetLength() > 0)
432     {
433         WriteBytes(writeOffset, chunk.GetBytes(), chunk.GetLength());
434         writeOffset += chunk.GetLength();
435         aMessage.GetNextChunk(aLength, chunk);
436     }
437 
438 exit:
439     return error;
440 }
441 
/**
 * Prepends `aLength` bytes in front of the message payload, using (and if needed, growing) the reserved
 * header space.
 *
 * The offset is advanced by `aLength` so it keeps referring to the same payload byte. When `aBuf` is
 * `nullptr`, the prepended bytes are not written by this method.
 *
 * @param[in] aBuf     Bytes to prepend, or `nullptr` to only make room.
 * @param[in] aLength  Number of bytes to prepend.
 *
 * @retval kErrorNone    The bytes were prepended.
 * @retval kErrorNoBufs  A buffer could not be allocated. Buffers inserted by earlier loop iterations remain
 *                       linked and are reflected in the enlarged reserved size.
 */
Error Message::PrependBytes(const void *aBuf, uint16_t aLength)
{
    Error   error     = kErrorNone;
    Buffer *newBuffer = nullptr;

    // Grow the reserved header one buffer at a time until it can hold `aLength` bytes.
    while (aLength > GetReserved())
    {
        VerifyOrExit((newBuffer = GetMessagePool()->NewBuffer(GetPriority())) != nullptr, error = kErrorNoBufs);

        // Insert the new buffer directly after the head buffer.
        newBuffer->SetNextBuffer(GetNextBuffer());
        SetNextBuffer(newBuffer);

        if (GetReserved() < sizeof(mBuffer.mHead.mData))
        {
            // Copy payload from the first buffer.
            // The destination is addressed through the head-buffer layout, so each byte lands at the same
            // `mData` index it had in the head. NOTE(review): this relies on the head layout placing the
            // metadata before `mData`, which shifts the copy to where the bytes belong once `mReserved`
            // grows by `kBufferDataSize` -- confirm against the `Buffer` definition.
            memcpy(newBuffer->mBuffer.mHead.mData + GetReserved(), mBuffer.mHead.mData + GetReserved(),
                   sizeof(mBuffer.mHead.mData) - GetReserved());
        }

        SetReserved(GetReserved() + kBufferDataSize);
    }

    // Consume `aLength` reserved bytes as payload.
    SetReserved(GetReserved() - aLength);
    GetMetadata().mLength += aLength;
    SetOffset(GetOffset() + aLength);

    if (aBuf != nullptr)
    {
        WriteBytes(0, aBuf, aLength);
    }

exit:
    return error;
}
476 
RemoveHeader(uint16_t aLength)477 void Message::RemoveHeader(uint16_t aLength)
478 {
479     OT_ASSERT(aLength <= GetMetadata().mLength);
480 
481     GetMetadata().mReserved += aLength;
482     GetMetadata().mLength -= aLength;
483 
484     if (GetMetadata().mOffset > aLength)
485     {
486         GetMetadata().mOffset -= aLength;
487     }
488     else
489     {
490         GetMetadata().mOffset = 0;
491     }
492 }
493 
/**
 * Removes `aLength` bytes located at `aOffset`, keeping the `aOffset` bytes that precede them.
 *
 * @param[in] aOffset  Offset of the bytes to remove (also the number of leading bytes preserved).
 * @param[in] aLength  Number of bytes to remove.
 */
void Message::RemoveHeader(uint16_t aOffset, uint16_t aLength)
{
    // To shrink the header, we copy the header bytes before `aOffset`
    // forward. Starting at offset `aLength`, we write bytes we read
    // from offset `0` onward and copy a total of `aOffset` bytes.
    // Then remove the first `aLength` bytes from message.
    //
    //
    // 0                   aOffset  aOffset + aLength
    // +-----------------------+---------+------------------------+
    // | / / / / / / / / / / / | x x x x |                        |
    // +-----------------------+---------+------------------------+
    //
    // 0       aLength                aOffset + aLength
    // +---------+-----------------------+------------------------+
    // |         | / / / / / / / / / / / |                        |
    // +---------+-----------------------+------------------------+
    //
    //  0                    aOffset
    //  +-----------------------+------------------------+
    //  | / / / / / / / / / / / |                        |
    //  +-----------------------+------------------------+
    //
    // The copy must happen before `RemoveHeader(aLength)`; the two calls cannot be reordered.

    WriteBytesFromMessage(/* aWriteOffset */ aLength, *this, /* aReadOffset */ 0, /* aLength */ aOffset);
    RemoveHeader(aLength);
}
521 
InsertHeader(uint16_t aOffset,uint16_t aLength)522 Error Message::InsertHeader(uint16_t aOffset, uint16_t aLength)
523 {
524     Error error;
525 
526     // To make space in header at `aOffset`, we first prepend
527     // `aLength` bytes at front. Then copy the existing bytes
528     // backwards. Starting at offset `0`, we write bytes we read
529     // from offset `aLength` onward and copy a total of `aOffset`
530     // bytes.
531     //
532     // 0                    aOffset
533     // +-----------------------+------------------------+
534     // | / / / / / / / / / / / |                        |
535     // +-----------------------+------------------------+
536     //
537     // 0       aLength                aOffset + aLength
538     // +---------+-----------------------+------------------------+
539     // |         | / / / / / / / / / / / |                        |
540     // +---------+-----------------------+------------------------+
541     //
542     // 0                   aOffset  aOffset + aLength
543     // +-----------------------+---------+------------------------+
544     // | / / / / / / / / / / / |  N E W  |                        |
545     // +-----------------------+---------+------------------------+
546     //
547 
548     SuccessOrExit(error = PrependBytes(nullptr, aLength));
549     WriteBytesFromMessage(/* aWriteOffset */ 0, *this, /* aReadOffset */ aLength, /* aLength */ aOffset);
550 
551 exit:
552     return error;
553 }
554 
RemoveFooter(uint16_t aLength)555 void Message::RemoveFooter(uint16_t aLength) { IgnoreError(SetLength(GetLength() - Min(aLength, GetLength()))); }
556 
/**
 * Gets the first contiguous chunk of message data for the range starting at `aOffset`.
 *
 * @param[in]     aOffset  Payload offset (excluding the reserved header) where the range starts.
 * @param[in,out] aLength  On input, the requested range length. If the range extends past the end of the
 *                         message it is first clamped to the available bytes. On exit, it is reduced by the
 *                         returned chunk's length. If `aOffset` is at or beyond the message length, it is
 *                         left unchanged.
 * @param[out]    aChunk   Set to the chunk's start pointer, length, and owning buffer. The length is zero if
 *                         `aOffset` is at or beyond the message length.
 */
void Message::GetFirstChunk(uint16_t aOffset, uint16_t &aLength, Chunk &aChunk) const
{
    // This method gets the first message chunk (contiguous data
    // buffer) corresponding to a given offset and length. On exit
    // `aChunk` is updated such that `aChunk.GetBytes()` gives the
    // pointer to the start of chunk and `aChunk.GetLength()` gives
    // its length. The `aLength` is also decreased by the chunk
    // length.

    VerifyOrExit(aOffset < GetLength(), aChunk.SetLength(0));

    // Clamp the requested range to the bytes actually present.
    if (aOffset + aLength >= GetLength())
    {
        aLength = GetLength() - aOffset;
    }

    // Translate the payload offset into a position within the buffer chain (the reserved header
    // occupies the first `GetReserved()` bytes).
    aOffset += GetReserved();

    aChunk.SetBuffer(this);

    // Special case for the first buffer

    if (aOffset < kHeadBufferDataSize)
    {
        aChunk.Init(GetFirstData() + aOffset, kHeadBufferDataSize - aOffset);
        ExitNow();
    }

    aOffset -= kHeadBufferDataSize;

    // Find the `Buffer` matching the offset

    while (true)
    {
        aChunk.SetBuffer(aChunk.GetBuffer()->GetNextBuffer());

        OT_ASSERT(aChunk.GetBuffer() != nullptr);

        if (aOffset < kBufferDataSize)
        {
            aChunk.Init(aChunk.GetBuffer()->GetData() + aOffset, kBufferDataSize - aOffset);
            ExitNow();
        }

        aOffset -= kBufferDataSize;
    }

exit:
    // The chunk spans to the end of its buffer; trim it to the requested (clamped) length.
    if (aChunk.GetLength() > aLength)
    {
        aChunk.SetLength(aLength);
    }

    aLength -= aChunk.GetLength();
}
612 
GetNextChunk(uint16_t & aLength,Chunk & aChunk) const613 void Message::GetNextChunk(uint16_t &aLength, Chunk &aChunk) const
614 {
615     // This method gets the next message chunk. On input, the
616     // `aChunk` should be the previous chunk. On exit, it is
617     // updated to provide info about next chunk, and `aLength`
618     // is decreased by the chunk length. If there is no more
619     // chunk, `aChunk.GetLength()` would be zero.
620 
621     VerifyOrExit(aLength > 0, aChunk.SetLength(0));
622 
623     aChunk.SetBuffer(aChunk.GetBuffer()->GetNextBuffer());
624 
625     OT_ASSERT(aChunk.GetBuffer() != nullptr);
626 
627     aChunk.Init(aChunk.GetBuffer()->GetData(), kBufferDataSize);
628 
629     if (aChunk.GetLength() > aLength)
630     {
631         aChunk.SetLength(aLength);
632     }
633 
634     aLength -= aChunk.GetLength();
635 
636 exit:
637     return;
638 }
639 
ReadBytes(uint16_t aOffset,void * aBuf,uint16_t aLength) const640 uint16_t Message::ReadBytes(uint16_t aOffset, void *aBuf, uint16_t aLength) const
641 {
642     uint8_t *bufPtr = reinterpret_cast<uint8_t *>(aBuf);
643     Chunk    chunk;
644 
645     GetFirstChunk(aOffset, aLength, chunk);
646 
647     while (chunk.GetLength() > 0)
648     {
649         chunk.CopyBytesTo(bufPtr);
650         bufPtr += chunk.GetLength();
651         GetNextChunk(aLength, chunk);
652     }
653 
654     return static_cast<uint16_t>(bufPtr - reinterpret_cast<uint8_t *>(aBuf));
655 }
656 
ReadBytes(const OffsetRange & aOffsetRange,void * aBuf) const657 uint16_t Message::ReadBytes(const OffsetRange &aOffsetRange, void *aBuf) const
658 {
659     return ReadBytes(aOffsetRange.GetOffset(), aBuf, aOffsetRange.GetLength());
660 }
661 
Read(uint16_t aOffset,void * aBuf,uint16_t aLength) const662 Error Message::Read(uint16_t aOffset, void *aBuf, uint16_t aLength) const
663 {
664     return (ReadBytes(aOffset, aBuf, aLength) == aLength) ? kErrorNone : kErrorParse;
665 }
666 
Read(const OffsetRange & aOffsetRange,void * aBuf,uint16_t aLength) const667 Error Message::Read(const OffsetRange &aOffsetRange, void *aBuf, uint16_t aLength) const
668 {
669     Error error = kErrorNone;
670 
671     VerifyOrExit(aOffsetRange.Contains(aLength), error = kErrorParse);
672     VerifyOrExit(ReadBytes(aOffsetRange.GetOffset(), aBuf, aLength) == aLength, error = kErrorParse);
673 
674 exit:
675     return error;
676 }
677 
CompareBytes(uint16_t aOffset,const void * aBuf,uint16_t aLength,ByteMatcher aMatcher) const678 bool Message::CompareBytes(uint16_t aOffset, const void *aBuf, uint16_t aLength, ByteMatcher aMatcher) const
679 {
680     uint16_t       bytesToCompare = aLength;
681     const uint8_t *bufPtr         = reinterpret_cast<const uint8_t *>(aBuf);
682     Chunk          chunk;
683 
684     GetFirstChunk(aOffset, aLength, chunk);
685 
686     while (chunk.GetLength() > 0)
687     {
688         VerifyOrExit(chunk.MatchesBytesIn(bufPtr, aMatcher));
689         bufPtr += chunk.GetLength();
690         bytesToCompare -= chunk.GetLength();
691         GetNextChunk(aLength, chunk);
692     }
693 
694 exit:
695     return (bytesToCompare == 0);
696 }
697 
CompareBytes(uint16_t aOffset,const Message & aOtherMessage,uint16_t aOtherOffset,uint16_t aLength,ByteMatcher aMatcher) const698 bool Message::CompareBytes(uint16_t       aOffset,
699                            const Message &aOtherMessage,
700                            uint16_t       aOtherOffset,
701                            uint16_t       aLength,
702                            ByteMatcher    aMatcher) const
703 {
704     uint16_t bytesToCompare = aLength;
705     Chunk    chunk;
706 
707     GetFirstChunk(aOffset, aLength, chunk);
708 
709     while (chunk.GetLength() > 0)
710     {
711         VerifyOrExit(aOtherMessage.CompareBytes(aOtherOffset, chunk.GetBytes(), chunk.GetLength(), aMatcher));
712         aOtherOffset += chunk.GetLength();
713         bytesToCompare -= chunk.GetLength();
714         GetNextChunk(aLength, chunk);
715     }
716 
717 exit:
718     return (bytesToCompare == 0);
719 }
720 
WriteBytes(uint16_t aOffset,const void * aBuf,uint16_t aLength)721 void Message::WriteBytes(uint16_t aOffset, const void *aBuf, uint16_t aLength)
722 {
723     const uint8_t *bufPtr = reinterpret_cast<const uint8_t *>(aBuf);
724     MutableChunk   chunk;
725 
726     OT_ASSERT(aOffset + aLength <= GetLength());
727 
728     GetFirstChunk(aOffset, aLength, chunk);
729 
730     while (chunk.GetLength() > 0)
731     {
732         memmove(chunk.GetBytes(), bufPtr, chunk.GetLength());
733         bufPtr += chunk.GetLength();
734         GetNextChunk(aLength, chunk);
735     }
736 }
737 
/**
 * Copies `aLength` bytes from `aMessage` at `aReadOffset` into this message at `aWriteOffset`.
 *
 * `aMessage` may be this message; overlapping ranges are handled by choosing the copy direction.
 *
 * @param[in] aWriteOffset  Destination offset in this message.
 * @param[in] aMessage      Source message.
 * @param[in] aReadOffset   Source offset in `aMessage`.
 * @param[in] aLength       Number of bytes to copy.
 */
void Message::WriteBytesFromMessage(uint16_t       aWriteOffset,
                                    const Message &aMessage,
                                    uint16_t       aReadOffset,
                                    uint16_t       aLength)
{
    // A different message, or a backward move within this message: copying front-to-back, chunk by
    // chunk, never overwrites source bytes that have not been read yet.
    if ((&aMessage != this) || (aReadOffset >= aWriteOffset))
    {
        Chunk chunk;

        aMessage.GetFirstChunk(aReadOffset, aLength, chunk);

        while (chunk.GetLength() > 0)
        {
            WriteBytes(aWriteOffset, chunk.GetBytes(), chunk.GetLength());
            aWriteOffset += chunk.GetLength();
            aMessage.GetNextChunk(aLength, chunk);
        }
    }
    else
    {
        // We are copying bytes within the same message forward.
        // To ensure copy forward works, we read and write from
        // end of range and move backwards.

        static constexpr uint16_t kBufSize = 32;

        // Small bounce buffer: each step reads `copyLength` bytes before any of them are overwritten.
        uint8_t buf[kBufSize];

        aWriteOffset += aLength;
        aReadOffset += aLength;

        while (aLength > 0)
        {
            uint16_t copyLength = Min(kBufSize, aLength);

            aLength -= copyLength;
            aReadOffset -= copyLength;
            aWriteOffset -= copyLength;

            ReadBytes(aReadOffset, buf, copyLength);
            WriteBytes(aWriteOffset, buf, copyLength);
        }
    }
}
782 
Clone(uint16_t aLength) const783 Message *Message::Clone(uint16_t aLength) const
784 {
785     Error    error = kErrorNone;
786     Message *messageCopy;
787     Settings settings(IsLinkSecurityEnabled() ? kWithLinkSecurity : kNoLinkSecurity, GetPriority());
788     uint16_t offset;
789 
790     aLength     = Min(GetLength(), aLength);
791     messageCopy = GetMessagePool()->Allocate(GetType(), GetReserved(), settings);
792     VerifyOrExit(messageCopy != nullptr, error = kErrorNoBufs);
793     SuccessOrExit(error = messageCopy->AppendBytesFromMessage(*this, 0, aLength));
794 
795     // Copy selected message information.
796 
797     offset = Min(GetOffset(), aLength);
798     messageCopy->SetOffset(offset);
799 
800     messageCopy->SetSubType(GetSubType());
801     messageCopy->SetLoopbackToHostAllowed(IsLoopbackToHostAllowed());
802     messageCopy->SetOrigin(GetOrigin());
803     messageCopy->SetTimestamp(GetTimestamp());
804     messageCopy->SetMeshDest(GetMeshDest());
805     messageCopy->SetPanId(GetPanId());
806     messageCopy->SetChannel(GetChannel());
807     messageCopy->SetRssAverager(GetRssAverager());
808     messageCopy->SetLqiAverager(GetLqiAverager());
809 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
810     messageCopy->SetTimeSync(IsTimeSync());
811 #endif
812 
813 exit:
814     FreeAndNullMessageOnError(messageCopy, error);
815     return messageCopy;
816 }
817 
#if OPENTHREAD_FTD
// Returns whether the child-mask bit for `aChildIndex` is set.
bool Message::GetChildMask(uint16_t aChildIndex) const
{
    return GetMetadata().mChildMask.Get(aChildIndex);
}

// Clears the child-mask bit for `aChildIndex`.
void Message::ClearChildMask(uint16_t aChildIndex)
{
    GetMetadata().mChildMask.Set(aChildIndex, false);
}

// Sets the child-mask bit for `aChildIndex`.
void Message::SetChildMask(uint16_t aChildIndex)
{
    GetMetadata().mChildMask.Set(aChildIndex, true);
}

// Returns whether any child-mask bit is set.
bool Message::IsChildPending(void) const
{
    return GetMetadata().mChildMask.HasAny();
}
#endif
827 
GetLinkInfo(ThreadLinkInfo & aLinkInfo) const828 Error Message::GetLinkInfo(ThreadLinkInfo &aLinkInfo) const
829 {
830     Error error = kErrorNone;
831 
832     VerifyOrExit(IsOriginThreadNetif(), error = kErrorNotFound);
833 
834     aLinkInfo.Clear();
835 
836     aLinkInfo.mPanId               = GetPanId();
837     aLinkInfo.mChannel             = GetChannel();
838     aLinkInfo.mRss                 = GetAverageRss();
839     aLinkInfo.mLqi                 = GetAverageLqi();
840     aLinkInfo.mLinkSecurity        = IsLinkSecurityEnabled();
841     aLinkInfo.mIsDstPanIdBroadcast = IsDstPanIdBroadcast();
842 
843 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
844     aLinkInfo.mTimeSyncSeq       = GetTimeSyncSeq();
845     aLinkInfo.mNetworkTimeOffset = GetNetworkTimeOffset();
846 #endif
847 
848 #if OPENTHREAD_CONFIG_MULTI_RADIO
849     aLinkInfo.mRadioType = GetRadioType();
850 #endif
851 
852 exit:
853     return error;
854 }
855 
/**
 * Updates this message's link-layer metadata from `aLinkInfo`.
 *
 * PAN ID, channel, link security, and the destination-PAN-broadcast flag are overwritten. RSS and LQI are
 * fed through `AddRss()`/`AddLqi()` rather than assigned; `GetLinkInfo()` reads them back as averages.
 * Time-sync and radio-type fields are copied only when the corresponding features are enabled.
 *
 * @param[in] aLinkInfo  The link info to copy from.
 */
void Message::UpdateLinkInfoFrom(const ThreadLinkInfo &aLinkInfo)
{
    SetPanId(aLinkInfo.mPanId);
    SetChannel(aLinkInfo.mChannel);
    AddRss(aLinkInfo.mRss);
    AddLqi(aLinkInfo.mLqi);
    SetLinkSecurityEnabled(aLinkInfo.mLinkSecurity);
    GetMetadata().mIsDstPanIdBroadcast = aLinkInfo.IsDstPanIdBroadcast();

#if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
    SetTimeSyncSeq(aLinkInfo.mTimeSyncSeq);
    SetNetworkTimeOffset(aLinkInfo.mNetworkTimeOffset);
#endif

#if OPENTHREAD_CONFIG_MULTI_RADIO
    SetRadioType(static_cast<Mac::RadioType>(aLinkInfo.mRadioType));
#endif
}
874 
IsTimeSync(void) const875 bool Message::IsTimeSync(void) const
876 {
877 #if OPENTHREAD_CONFIG_TIME_SYNC_ENABLE
878     return GetMetadata().mTimeSync;
879 #else
880     return false;
881 #endif
882 }
883 
SetMessageQueue(MessageQueue * aMessageQueue)884 void Message::SetMessageQueue(MessageQueue *aMessageQueue)
885 {
886     GetMetadata().mQueue       = aMessageQueue;
887     GetMetadata().mInPriorityQ = false;
888 }
889 
SetPriorityQueue(PriorityQueue * aPriorityQueue)890 void Message::SetPriorityQueue(PriorityQueue *aPriorityQueue)
891 {
892     GetMetadata().mQueue       = aPriorityQueue;
893     GetMetadata().mInPriorityQ = true;
894 }
895 
896 //---------------------------------------------------------------------------------------------------------------------
897 // MessageQueue
898 
Enqueue(Message & aMessage,QueuePosition aPosition)899 void MessageQueue::Enqueue(Message &aMessage, QueuePosition aPosition)
900 {
901     OT_ASSERT(!aMessage.IsInAQueue());
902     OT_ASSERT((aMessage.Next() == nullptr) && (aMessage.Prev() == nullptr));
903 
904     aMessage.SetMessageQueue(this);
905 
906     if (GetTail() == nullptr)
907     {
908         aMessage.Next() = &aMessage;
909         aMessage.Prev() = &aMessage;
910 
911         SetTail(&aMessage);
912     }
913     else
914     {
915         Message *head = GetTail()->Next();
916 
917         aMessage.Next() = head;
918         aMessage.Prev() = GetTail();
919 
920         head->Prev()      = &aMessage;
921         GetTail()->Next() = &aMessage;
922 
923         if (aPosition == kQueuePositionTail)
924         {
925             SetTail(&aMessage);
926         }
927     }
928 }
929 
Dequeue(Message & aMessage)930 void MessageQueue::Dequeue(Message &aMessage)
931 {
932     OT_ASSERT(aMessage.GetMessageQueue() == this);
933     OT_ASSERT((aMessage.Next() != nullptr) && (aMessage.Prev() != nullptr));
934 
935     if (&aMessage == GetTail())
936     {
937         SetTail(GetTail()->Prev());
938 
939         if (&aMessage == GetTail())
940         {
941             SetTail(nullptr);
942         }
943     }
944 
945     aMessage.Prev()->Next() = aMessage.Next();
946     aMessage.Next()->Prev() = aMessage.Prev();
947 
948     aMessage.Prev() = nullptr;
949     aMessage.Next() = nullptr;
950 
951     aMessage.SetMessageQueue(nullptr);
952 }
953 
DequeueAndFree(Message & aMessage)954 void MessageQueue::DequeueAndFree(Message &aMessage)
955 {
956     Dequeue(aMessage);
957     aMessage.Free();
958 }
959 
DequeueAndFreeAll(void)960 void MessageQueue::DequeueAndFreeAll(void)
961 {
962     Message *message;
963 
964     while ((message = GetHead()) != nullptr)
965     {
966         DequeueAndFree(*message);
967     }
968 }
969 
begin(void)970 Message::Iterator MessageQueue::begin(void) { return Message::Iterator(GetHead()); }
971 
begin(void) const972 Message::ConstIterator MessageQueue::begin(void) const { return Message::ConstIterator(GetHead()); }
973 
GetInfo(Info & aInfo) const974 void MessageQueue::GetInfo(Info &aInfo) const
975 {
976     for (const Message &message : *this)
977     {
978         aInfo.mNumMessages++;
979         aInfo.mNumBuffers += message.GetBufferCount();
980         aInfo.mTotalBytes += message.GetLength();
981     }
982 }
983 
984 //---------------------------------------------------------------------------------------------------------------------
985 // PriorityQueue
986 
FindFirstNonNullTail(Message::Priority aStartPriorityLevel) const987 const Message *PriorityQueue::FindFirstNonNullTail(Message::Priority aStartPriorityLevel) const
988 {
989     // Find the first non-`nullptr` tail starting from the given priority
990     // level and moving forward (wrapping from priority value
991     // `kNumPriorities` -1 back to 0).
992 
993     const Message *tail = nullptr;
994     uint8_t        priority;
995 
996     priority = static_cast<uint8_t>(aStartPriorityLevel);
997 
998     do
999     {
1000         if (mTails[priority] != nullptr)
1001         {
1002             tail = mTails[priority];
1003             break;
1004         }
1005 
1006         priority = PrevPriority(priority);
1007     } while (priority != aStartPriorityLevel);
1008 
1009     return tail;
1010 }
1011 
GetHead(void) const1012 const Message *PriorityQueue::GetHead(void) const
1013 {
1014     return Message::NextOf(FindFirstNonNullTail(Message::kPriorityLow));
1015 }
1016 
GetHeadForPriority(Message::Priority aPriority) const1017 const Message *PriorityQueue::GetHeadForPriority(Message::Priority aPriority) const
1018 {
1019     const Message *head;
1020     const Message *previousTail;
1021 
1022     if (mTails[aPriority] != nullptr)
1023     {
1024         previousTail = FindFirstNonNullTail(static_cast<Message::Priority>(PrevPriority(aPriority)));
1025 
1026         OT_ASSERT(previousTail != nullptr);
1027 
1028         head = previousTail->Next();
1029     }
1030     else
1031     {
1032         head = nullptr;
1033     }
1034 
1035     return head;
1036 }
1037 
GetTail(void) const1038 const Message *PriorityQueue::GetTail(void) const { return FindFirstNonNullTail(Message::kPriorityLow); }
1039 
Enqueue(Message & aMessage)1040 void PriorityQueue::Enqueue(Message &aMessage)
1041 {
1042     Message::Priority priority;
1043     Message          *tail;
1044     Message          *next;
1045 
1046     OT_ASSERT(!aMessage.IsInAQueue());
1047 
1048     aMessage.SetPriorityQueue(this);
1049 
1050     priority = aMessage.GetPriority();
1051 
1052     tail = FindFirstNonNullTail(priority);
1053 
1054     if (tail != nullptr)
1055     {
1056         next = tail->Next();
1057 
1058         aMessage.Next() = next;
1059         aMessage.Prev() = tail;
1060         next->Prev()    = &aMessage;
1061         tail->Next()    = &aMessage;
1062     }
1063     else
1064     {
1065         aMessage.Next() = &aMessage;
1066         aMessage.Prev() = &aMessage;
1067     }
1068 
1069     mTails[priority] = &aMessage;
1070 }
1071 
Dequeue(Message & aMessage)1072 void PriorityQueue::Dequeue(Message &aMessage)
1073 {
1074     Message::Priority priority;
1075     Message          *tail;
1076 
1077     OT_ASSERT(aMessage.GetPriorityQueue() == this);
1078 
1079     priority = aMessage.GetPriority();
1080 
1081     tail = mTails[priority];
1082 
1083     if (&aMessage == tail)
1084     {
1085         tail = tail->Prev();
1086 
1087         if ((&aMessage == tail) || (tail->GetPriority() != priority))
1088         {
1089             tail = nullptr;
1090         }
1091 
1092         mTails[priority] = tail;
1093     }
1094 
1095     aMessage.Next()->Prev() = aMessage.Prev();
1096     aMessage.Prev()->Next() = aMessage.Next();
1097     aMessage.Next()         = nullptr;
1098     aMessage.Prev()         = nullptr;
1099 
1100     aMessage.SetPriorityQueue(nullptr);
1101 }
1102 
DequeueAndFree(Message & aMessage)1103 void PriorityQueue::DequeueAndFree(Message &aMessage)
1104 {
1105     Dequeue(aMessage);
1106     aMessage.Free();
1107 }
1108 
DequeueAndFreeAll(void)1109 void PriorityQueue::DequeueAndFreeAll(void)
1110 {
1111     Message *message;
1112 
1113     while ((message = GetHead()) != nullptr)
1114     {
1115         DequeueAndFree(*message);
1116     }
1117 }
1118 
begin(void)1119 Message::Iterator PriorityQueue::begin(void) { return Message::Iterator(GetHead()); }
1120 
begin(void) const1121 Message::ConstIterator PriorityQueue::begin(void) const { return Message::ConstIterator(GetHead()); }
1122 
GetInfo(Info & aInfo) const1123 void PriorityQueue::GetInfo(Info &aInfo) const
1124 {
1125     for (const Message &message : *this)
1126     {
1127         aInfo.mNumMessages++;
1128         aInfo.mNumBuffers += message.GetBufferCount();
1129         aInfo.mTotalBytes += message.GetLength();
1130     }
1131 }
1132 
1133 } // namespace ot
1134 #endif // OPENTHREAD_MTD || OPENTHREAD_FTD
1135