1 /*
2  *  Copyright (c) 2020, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements generating and processing of DNS headers and helper functions/methods.
32  */
33 
34 #include "dns_types.hpp"
35 
36 #include "common/code_utils.hpp"
37 #include "common/debug.hpp"
38 #include "common/num_utils.hpp"
39 #include "common/numeric_limits.hpp"
40 #include "common/random.hpp"
41 #include "common/string.hpp"
42 #include "instance/instance.hpp"
43 
44 namespace ot {
45 namespace Dns {
46 
SetRandomMessageId(void)47 Error Header::SetRandomMessageId(void)
48 {
49     return Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&mMessageId), sizeof(mMessageId));
50 }
51 
ResponseCodeToError(Response aResponse)52 Error Header::ResponseCodeToError(Response aResponse)
53 {
54     Error error = kErrorFailed;
55 
56     switch (aResponse)
57     {
58     case kResponseSuccess:
59         error = kErrorNone;
60         break;
61 
62     case kResponseFormatError:   // Server unable to interpret request due to format error.
63     case kResponseBadName:       // Bad name.
64     case kResponseBadTruncation: // Bad truncation.
65     case kResponseNotZone:       // A name is not in the zone.
66         error = kErrorParse;
67         break;
68 
69     case kResponseServerFailure: // Server encountered an internal failure.
70         error = kErrorFailed;
71         break;
72 
73     case kResponseNameError:       // Name that ought to exist, does not exists.
74     case kResponseRecordNotExists: // Some RRset that ought to exist, does not exist.
75         error = kErrorNotFound;
76         break;
77 
78     case kResponseNotImplemented: // Server does not support the query type (OpCode).
79     case kDsoTypeNotImplemented:  // DSO TLV type is not implemented.
80         error = kErrorNotImplemented;
81         break;
82 
83     case kResponseBadAlg: // Bad algorithm.
84         error = kErrorNotCapable;
85         break;
86 
87     case kResponseNameExists:   // Some name that ought not to exist, does exist.
88     case kResponseRecordExists: // Some RRset that ought not to exist, does exist.
89         error = kErrorDuplicated;
90         break;
91 
92     case kResponseRefused: // Server refused to perform operation for policy or security reasons.
93     case kResponseNotAuth: // Service is not authoritative for zone.
94         error = kErrorSecurity;
95         break;
96 
97     default:
98         break;
99     }
100 
101     return error;
102 }
103 
Matches(const char * aFirstLabel,const char * aLabels,const char * aDomain) const104 bool Name::Matches(const char *aFirstLabel, const char *aLabels, const char *aDomain) const
105 {
106     bool matches = false;
107 
108     VerifyOrExit(!IsEmpty());
109 
110     if (IsFromCString())
111     {
112         const char *namePtr = mString;
113 
114         if (aFirstLabel != nullptr)
115         {
116             matches = CompareAndSkipLabels(namePtr, aFirstLabel, kLabelSeparatorChar);
117             VerifyOrExit(matches);
118         }
119 
120         matches = CompareAndSkipLabels(namePtr, aLabels, kLabelSeparatorChar);
121         VerifyOrExit(matches);
122 
123         matches = CompareAndSkipLabels(namePtr, aDomain, kNullChar);
124     }
125     else
126     {
127         uint16_t offset = mOffset;
128 
129         if (aFirstLabel != nullptr)
130         {
131             SuccessOrExit(CompareLabel(*mMessage, offset, aFirstLabel));
132         }
133 
134         SuccessOrExit(CompareMultipleLabels(*mMessage, offset, aLabels));
135         SuccessOrExit(CompareName(*mMessage, offset, aDomain));
136         matches = true;
137     }
138 
139 exit:
140     return matches;
141 }
142 
CompareAndSkipLabels(const char * & aNamePtr,const char * aLabels,char aExpectedNextChar)143 bool Name::CompareAndSkipLabels(const char *&aNamePtr, const char *aLabels, char aExpectedNextChar)
144 {
145     // Compares `aNamePtr` to the label string `aLabels` followed by
146     // the `aExpectedNextChar`(using case-insensitive match). Upon
147     // successful comparison, `aNamePtr` is advanced to point after
148     // the matched portion.
149 
150     bool     matches = false;
151     uint16_t len     = StringLength(aLabels, kMaxNameSize);
152 
153     VerifyOrExit(len < kMaxNameSize);
154 
155     VerifyOrExit(StringStartsWith(aNamePtr, aLabels, kStringCaseInsensitiveMatch));
156     aNamePtr += len;
157 
158     VerifyOrExit(*aNamePtr == aExpectedNextChar);
159     aNamePtr++;
160 
161     matches = true;
162 
163 exit:
164     return matches;
165 }
166 
AppendTo(Message & aMessage) const167 Error Name::AppendTo(Message &aMessage) const
168 {
169     Error error;
170 
171     if (IsEmpty())
172     {
173         error = AppendTerminator(aMessage);
174     }
175     else if (IsFromCString())
176     {
177         error = AppendName(GetAsCString(), aMessage);
178     }
179     else
180     {
181         // Name is from a message. Read labels one by one from
182         // `mMessage` and and append each to the `aMessage`.
183 
184         LabelIterator iterator(*mMessage, mOffset);
185 
186         while (true)
187         {
188             error = iterator.GetNextLabel();
189 
190             switch (error)
191             {
192             case kErrorNone:
193                 SuccessOrExit(error = iterator.AppendLabel(aMessage));
194                 break;
195 
196             case kErrorNotFound:
197                 // We reached the end of name successfully.
198                 error = AppendTerminator(aMessage);
199 
200                 OT_FALL_THROUGH;
201 
202             default:
203                 ExitNow();
204             }
205         }
206     }
207 
208 exit:
209     return error;
210 }
211 
AppendLabel(const char * aLabel,Message & aMessage)212 Error Name::AppendLabel(const char *aLabel, Message &aMessage)
213 {
214     return AppendLabel(aLabel, static_cast<uint8_t>(StringLength(aLabel, kMaxLabelSize)), aMessage);
215 }
216 
AppendLabel(const char * aLabel,uint8_t aLength,Message & aMessage)217 Error Name::AppendLabel(const char *aLabel, uint8_t aLength, Message &aMessage)
218 {
219     Error error = kErrorNone;
220 
221     VerifyOrExit((0 < aLength) && (aLength <= kMaxLabelLength), error = kErrorInvalidArgs);
222 
223     SuccessOrExit(error = aMessage.Append(aLength));
224     error = aMessage.AppendBytes(aLabel, aLength);
225 
226 exit:
227     return error;
228 }
229 
AppendMultipleLabels(const char * aLabels,Message & aMessage)230 Error Name::AppendMultipleLabels(const char *aLabels, Message &aMessage)
231 {
232     Error    error           = kErrorNone;
233     uint16_t index           = 0;
234     uint16_t labelStartIndex = 0;
235     char     ch;
236 
237     VerifyOrExit(aLabels != nullptr);
238 
239     do
240     {
241         ch = aLabels[index];
242 
243         if ((ch == kNullChar) || (ch == kLabelSeparatorChar))
244         {
245             uint8_t labelLength = static_cast<uint8_t>(index - labelStartIndex);
246 
247             if (labelLength == 0)
248             {
249                 // Empty label (e.g., consecutive dots) is invalid, but we
250                 // allow for two cases: (1) where `aLabels` ends with a dot
251                 // (`labelLength` is zero but we are at end of `aLabels` string
252                 // and `ch` is null char. (2) if `aLabels` is just "." (we
253                 // see a dot at index 0, and index 1 is null char).
254 
255                 error =
256                     ((ch == kNullChar) || ((index == 0) && (aLabels[1] == kNullChar))) ? kErrorNone : kErrorInvalidArgs;
257                 ExitNow();
258             }
259 
260             VerifyOrExit(index + 1 < kMaxEncodedLength, error = kErrorInvalidArgs);
261             SuccessOrExit(error = AppendLabel(&aLabels[labelStartIndex], labelLength, aMessage));
262 
263             labelStartIndex = index + 1;
264         }
265 
266         index++;
267 
268     } while (ch != kNullChar);
269 
270 exit:
271     return error;
272 }
273 
AppendTerminator(Message & aMessage)274 Error Name::AppendTerminator(Message &aMessage)
275 {
276     uint8_t terminator = 0;
277 
278     return aMessage.Append(terminator);
279 }
280 
AppendPointerLabel(uint16_t aOffset,Message & aMessage)281 Error Name::AppendPointerLabel(uint16_t aOffset, Message &aMessage)
282 {
283     Error    error;
284     uint16_t value;
285 
286 #if OPENTHREAD_CONFIG_REFERENCE_DEVICE_ENABLE
287     if (!Instance::IsDnsNameCompressionEnabled())
288     {
289         // If "DNS name compression" mode is disabled, instead of
290         // appending the pointer label, read the name from the message
291         // and append it uncompressed. Note that the `aOffset` parameter
292         // in this method is given relative to the start of DNS header
293         // in `aMessage` (which `aMessage.GetOffset()` specifies).
294 
295         error = Name(aMessage, aOffset + aMessage.GetOffset()).AppendTo(aMessage);
296         ExitNow();
297     }
298 #endif
299 
300     // A pointer label takes the form of a two byte sequence as a
301     // `uint16_t` value. The first two bits are ones. This allows a
302     // pointer to be distinguished from a text label, since the text
303     // label must begin with two zero bits (note that labels are
304     // restricted to 63 octets or less). The next 14-bits specify
305     // an offset value relative to start of DNS header.
306 
307     OT_ASSERT(aOffset < kPointerLabelTypeUint16);
308 
309     value = BigEndian::HostSwap16(aOffset | kPointerLabelTypeUint16);
310 
311     ExitNow(error = aMessage.Append(value));
312 
313 exit:
314     return error;
315 }
316 
AppendName(const char * aName,Message & aMessage)317 Error Name::AppendName(const char *aName, Message &aMessage)
318 {
319     Error error;
320 
321     SuccessOrExit(error = AppendMultipleLabels(aName, aMessage));
322     error = AppendTerminator(aMessage);
323 
324 exit:
325     return error;
326 }
327 
ParseName(const Message & aMessage,uint16_t & aOffset)328 Error Name::ParseName(const Message &aMessage, uint16_t &aOffset)
329 {
330     Error         error;
331     LabelIterator iterator(aMessage, aOffset);
332 
333     while (true)
334     {
335         error = iterator.GetNextLabel();
336 
337         switch (error)
338         {
339         case kErrorNone:
340             break;
341 
342         case kErrorNotFound:
343             // We reached the end of name successfully.
344             aOffset = iterator.mNameEndOffset;
345             error   = kErrorNone;
346 
347             OT_FALL_THROUGH;
348 
349         default:
350             ExitNow();
351         }
352     }
353 
354 exit:
355     return error;
356 }
357 
ReadLabel(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t & aLabelLength)358 Error Name::ReadLabel(const Message &aMessage, uint16_t &aOffset, char *aLabelBuffer, uint8_t &aLabelLength)
359 {
360     Error         error;
361     LabelIterator iterator(aMessage, aOffset);
362 
363     SuccessOrExit(error = iterator.GetNextLabel());
364     SuccessOrExit(error = iterator.ReadLabel(aLabelBuffer, aLabelLength, /* aAllowDotCharInLabel */ true));
365     aOffset = iterator.mNextLabelOffset;
366 
367 exit:
368     return error;
369 }
370 
ReadName(const Message & aMessage,uint16_t & aOffset,char * aNameBuffer,uint16_t aNameBufferSize)371 Error Name::ReadName(const Message &aMessage, uint16_t &aOffset, char *aNameBuffer, uint16_t aNameBufferSize)
372 {
373     Error         error;
374     LabelIterator iterator(aMessage, aOffset);
375     bool          firstLabel = true;
376     uint8_t       labelLength;
377 
378     while (true)
379     {
380         error = iterator.GetNextLabel();
381 
382         switch (error)
383         {
384         case kErrorNone:
385 
386             if (!firstLabel)
387             {
388                 *aNameBuffer++ = kLabelSeparatorChar;
389                 aNameBufferSize--;
390 
391                 // No need to check if we have reached end of the name buffer
392                 // here since `iterator.ReadLabel()` would verify it.
393             }
394 
395             labelLength = static_cast<uint8_t>(Min(static_cast<uint16_t>(kMaxLabelSize), aNameBufferSize));
396             SuccessOrExit(error = iterator.ReadLabel(aNameBuffer, labelLength, /* aAllowDotCharInLabel */ firstLabel));
397             aNameBuffer += labelLength;
398             aNameBufferSize -= labelLength;
399             firstLabel = false;
400             break;
401 
402         case kErrorNotFound:
403             // We reach the end of name successfully. Always add a terminating dot
404             // at the end.
405             *aNameBuffer++ = kLabelSeparatorChar;
406             aNameBufferSize--;
407             VerifyOrExit(aNameBufferSize >= sizeof(uint8_t), error = kErrorNoBufs);
408             *aNameBuffer = kNullChar;
409             aOffset      = iterator.mNameEndOffset;
410             error        = kErrorNone;
411 
412             OT_FALL_THROUGH;
413 
414         default:
415             ExitNow();
416         }
417     }
418 
419 exit:
420     return error;
421 }
422 
CompareLabel(const Message & aMessage,uint16_t & aOffset,const char * aLabel)423 Error Name::CompareLabel(const Message &aMessage, uint16_t &aOffset, const char *aLabel)
424 {
425     Error         error;
426     LabelIterator iterator(aMessage, aOffset);
427 
428     SuccessOrExit(error = iterator.GetNextLabel());
429     VerifyOrExit(iterator.CompareLabel(aLabel, kIsSingleLabel), error = kErrorNotFound);
430     aOffset = iterator.mNextLabelOffset;
431 
432 exit:
433     return error;
434 }
435 
CompareMultipleLabels(const Message & aMessage,uint16_t & aOffset,const char * aLabels)436 Error Name::CompareMultipleLabels(const Message &aMessage, uint16_t &aOffset, const char *aLabels)
437 {
438     Error         error;
439     LabelIterator iterator(aMessage, aOffset);
440 
441     while (true)
442     {
443         SuccessOrExit(error = iterator.GetNextLabel());
444         VerifyOrExit(iterator.CompareLabel(aLabels, !kIsSingleLabel), error = kErrorNotFound);
445 
446         if (*aLabels == kNullChar)
447         {
448             aOffset = iterator.mNextLabelOffset;
449             ExitNow();
450         }
451     }
452 
453 exit:
454     return error;
455 }
456 
CompareName(const Message & aMessage,uint16_t & aOffset,const char * aName)457 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const char *aName)
458 {
459     Error         error;
460     LabelIterator iterator(aMessage, aOffset);
461     bool          matches = true;
462 
463     if (*aName == kLabelSeparatorChar)
464     {
465         aName++;
466         VerifyOrExit(*aName == kNullChar, error = kErrorInvalidArgs);
467     }
468 
469     while (true)
470     {
471         error = iterator.GetNextLabel();
472 
473         switch (error)
474         {
475         case kErrorNone:
476             if (matches && !iterator.CompareLabel(aName, !kIsSingleLabel))
477             {
478                 matches = false;
479             }
480 
481             break;
482 
483         case kErrorNotFound:
484             // We reached the end of the name in `aMessage`. We check if
485             // all the previous labels matched so far, and we are also
486             // at the end of `aName` string (see null char), then we
487             // return `kErrorNone` indicating a successful comparison
488             // (full match). Otherwise we return `kErrorNotFound` to
489             // indicate failed comparison.
490 
491             if (matches && (*aName == kNullChar))
492             {
493                 error = kErrorNone;
494             }
495 
496             aOffset = iterator.mNameEndOffset;
497 
498             OT_FALL_THROUGH;
499 
500         default:
501             ExitNow();
502         }
503     }
504 
505 exit:
506     return error;
507 }
508 
CompareName(const Message & aMessage,uint16_t & aOffset,const Message & aMessage2,uint16_t aOffset2)509 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Message &aMessage2, uint16_t aOffset2)
510 {
511     Error         error;
512     LabelIterator iterator(aMessage, aOffset);
513     LabelIterator iterator2(aMessage2, aOffset2);
514     bool          matches = true;
515 
516     while (true)
517     {
518         error = iterator.GetNextLabel();
519 
520         switch (error)
521         {
522         case kErrorNone:
523             // If all the previous labels matched so far, then verify
524             // that we can get the next label on `iterator2` and that it
525             // matches the label from `iterator`.
526             if (matches && (iterator2.GetNextLabel() != kErrorNone || !iterator.CompareLabel(iterator2)))
527             {
528                 matches = false;
529             }
530 
531             break;
532 
533         case kErrorNotFound:
534             // We reached the end of the name in `aMessage`. We check
535             // that `iterator2` is also at its end, and if all previous
536             // labels matched we return `kErrorNone`.
537 
538             if (matches && (iterator2.GetNextLabel() == kErrorNotFound))
539             {
540                 error = kErrorNone;
541             }
542 
543             aOffset = iterator.mNameEndOffset;
544 
545             OT_FALL_THROUGH;
546 
547         default:
548             ExitNow();
549         }
550     }
551 
552 exit:
553     return error;
554 }
555 
CompareName(const Message & aMessage,uint16_t & aOffset,const Name & aName)556 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Name &aName)
557 {
558     return aName.IsFromCString()
559                ? CompareName(aMessage, aOffset, aName.mString)
560                : (aName.IsFromMessage() ? CompareName(aMessage, aOffset, *aName.mMessage, aName.mOffset)
561                                         : ParseName(aMessage, aOffset));
562 }
563 
Error Name::LabelIterator::GetNextLabel(void)
{
    // Advances the iterator to the next text label of the name, following
    // any compression pointers on the way.
    //
    // Returns:
    //   - `kErrorNone` when a text label is found. `mLabelStartOffset` and
    //     `mLabelLength` then describe it, and `mNextLabelOffset` points
    //     just past it.
    //   - `kErrorNotFound` when the zero-length terminating label is
    //     reached (end of the name).
    //   - `kErrorParse` for an unknown label type, or for a pointer whose
    //     target is not below `mMinLabelOffset`. Each pointer must
    //     therefore jump strictly backwards from the previous pointer
    //     target, which rules out pointer loops. The initial value of
    //     `mMinLabelOffset` is set elsewhere (presumably the constructor);
    //     confirm.
    //   - Any error returned by `mMessage.Read()`.
    //
    // The first time the end of the name's encoding at its original
    // location is seen (the terminator, or the first pointer),
    // `mNameEndOffset` is recorded, unless it is already set. Callers use
    // it to skip over the name.

    Error error;

    while (true)
    {
        uint8_t labelLength;
        uint8_t labelType;

        SuccessOrExit(error = mMessage.Read(mNextLabelOffset, labelLength));

        // `kLabelTypeMask` extracts the label-type bits from the first byte.
        labelType = labelLength & kLabelTypeMask;

        if (labelType == kTextLabelType)
        {
            if (labelLength == 0)
            {
                // Zero label length indicates end of a name.

                if (!IsEndOffsetSet())
                {
                    mNameEndOffset = mNextLabelOffset + sizeof(uint8_t);
                }

                ExitNow(error = kErrorNotFound);
            }

            // Text label: a length byte followed by `labelLength` chars.
            mLabelStartOffset = mNextLabelOffset + sizeof(uint8_t);
            mLabelLength      = labelLength;
            mNextLabelOffset  = mLabelStartOffset + labelLength;

            // `error` still holds `kErrorNone` from the successful `Read()`.
            ExitNow();
        }
        else if (labelType == kPointerLabelType)
        {
            // A pointer label takes the form of a two byte sequence as a
            // `uint16_t` value. The first two bits are ones. The next 14 bits
            // specify an offset value from the start of the DNS header.

            uint16_t pointerValue;
            uint16_t nextLabelOffset;

            SuccessOrExit(error = mMessage.Read(mNextLabelOffset, pointerValue));

            // Only the first pointer (at the name's original location)
            // determines where the encoded name ends.
            if (!IsEndOffsetSet())
            {
                mNameEndOffset = mNextLabelOffset + sizeof(uint16_t);
            }

            // `mMessage.GetOffset()` must point to the start of the
            // DNS header.
            nextLabelOffset = mMessage.GetOffset() + (BigEndian::HostSwap16(pointerValue) & kPointerLabelOffsetMask);

            // Require each jump to go strictly below the previous target
            // so that a malicious or corrupt message cannot create a loop.
            VerifyOrExit(nextLabelOffset < mMinLabelOffset, error = kErrorParse);
            mNextLabelOffset = nextLabelOffset;
            mMinLabelOffset  = nextLabelOffset;

            // Go back through the `while(true)` loop to get the next label.
        }
        else
        {
            // Any other label type (reserved/extended) is not supported.
            ExitNow(error = kErrorParse);
        }
    }

exit:
    return error;
}
630 
ReadLabel(char * aLabelBuffer,uint8_t & aLabelLength,bool aAllowDotCharInLabel) const631 Error Name::LabelIterator::ReadLabel(char *aLabelBuffer, uint8_t &aLabelLength, bool aAllowDotCharInLabel) const
632 {
633     Error error;
634 
635     VerifyOrExit(mLabelLength < aLabelLength, error = kErrorNoBufs);
636 
637     SuccessOrExit(error = mMessage.Read(mLabelStartOffset, aLabelBuffer, mLabelLength));
638     aLabelBuffer[mLabelLength] = kNullChar;
639     aLabelLength               = mLabelLength;
640 
641     if (!aAllowDotCharInLabel)
642     {
643         VerifyOrExit(StringFind(aLabelBuffer, kLabelSeparatorChar) == nullptr, error = kErrorParse);
644     }
645 
646 exit:
647     return error;
648 }
649 
CaseInsensitiveMatch(uint8_t aFirst,uint8_t aSecond)650 bool Name::LabelIterator::CaseInsensitiveMatch(uint8_t aFirst, uint8_t aSecond)
651 {
652     return ToLowercase(static_cast<char>(aFirst)) == ToLowercase(static_cast<char>(aSecond));
653 }
654 
CompareLabel(const char * & aName,bool aIsSingleLabel) const655 bool Name::LabelIterator::CompareLabel(const char *&aName, bool aIsSingleLabel) const
656 {
657     // This method compares the current label in the iterator with the
658     // `aName` string. `aIsSingleLabel` indicates whether `aName` is a
659     // single label, or a sequence of labels separated by dot '.' char.
660     // If the label matches `aName`, then `aName` pointer is moved
661     // forward to the start of the next label (skipping over the `.`
662     // char). This method returns `true` when the labels match, `false`
663     // otherwise.
664 
665     bool matches = false;
666 
667     VerifyOrExit(StringLength(aName, mLabelLength) == mLabelLength);
668     matches = mMessage.CompareBytes(mLabelStartOffset, aName, mLabelLength, CaseInsensitiveMatch);
669 
670     VerifyOrExit(matches);
671 
672     aName += mLabelLength;
673 
674     // If `aName` is a single label, we should be also at the end of the
675     // `aName` string. Otherwise, we should see either null or dot '.'
676     // character (in case `aName` contains multiple labels).
677 
678     matches = (*aName == kNullChar);
679 
680     if (!aIsSingleLabel && (*aName == kLabelSeparatorChar))
681     {
682         matches = true;
683         aName++;
684     }
685 
686 exit:
687     return matches;
688 }
689 
CompareLabel(const LabelIterator & aOtherIterator) const690 bool Name::LabelIterator::CompareLabel(const LabelIterator &aOtherIterator) const
691 {
692     // This method compares the current label in the iterator with the
693     // label from another iterator.
694 
695     return (mLabelLength == aOtherIterator.mLabelLength) &&
696            mMessage.CompareBytes(mLabelStartOffset, aOtherIterator.mMessage, aOtherIterator.mLabelStartOffset,
697                                  mLabelLength, CaseInsensitiveMatch);
698 }
699 
AppendLabel(Message & aMessage) const700 Error Name::LabelIterator::AppendLabel(Message &aMessage) const
701 {
702     // This method reads and appends the current label in the iterator
703     // to `aMessage`.
704 
705     Error error;
706 
707     VerifyOrExit((0 < mLabelLength) && (mLabelLength <= kMaxLabelLength), error = kErrorInvalidArgs);
708     SuccessOrExit(error = aMessage.Append(mLabelLength));
709     error = aMessage.AppendBytesFromMessage(mMessage, mLabelStartOffset, mLabelLength);
710 
711 exit:
712     return error;
713 }
714 
ExtractLabels(const char * aName,const char * aSuffixName,char * aLabels,uint16_t aLabelsSize)715 Error Name::ExtractLabels(const char *aName, const char *aSuffixName, char *aLabels, uint16_t aLabelsSize)
716 {
717     Error       error        = kErrorParse;
718     uint16_t    nameLength   = StringLength(aName, kMaxNameSize);
719     uint16_t    suffixLength = StringLength(aSuffixName, kMaxNameSize);
720     const char *suffixStart;
721 
722     VerifyOrExit(nameLength < kMaxNameSize);
723     VerifyOrExit(suffixLength < kMaxNameSize);
724 
725     VerifyOrExit(nameLength > suffixLength);
726 
727     suffixStart = aName + nameLength - suffixLength;
728     VerifyOrExit(StringMatch(suffixStart, aSuffixName, kStringCaseInsensitiveMatch));
729     suffixStart--;
730     VerifyOrExit(*suffixStart == kLabelSeparatorChar);
731 
732     // Determine the labels length to copy
733     nameLength -= (suffixLength + 1);
734     VerifyOrExit(nameLength < aLabelsSize, error = kErrorNoBufs);
735 
736     if (aLabels != aName)
737     {
738         memmove(aLabels, aName, nameLength);
739     }
740 
741     aLabels[nameLength] = kNullChar;
742     error               = kErrorNone;
743 
744 exit:
745     return error;
746 }
747 
IsSubDomainOf(const char * aName,const char * aDomain)748 bool Name::IsSubDomainOf(const char *aName, const char *aDomain)
749 {
750     bool     match             = false;
751     bool     nameEndsWithDot   = false;
752     bool     domainEndsWithDot = false;
753     uint16_t nameLength        = StringLength(aName, kMaxNameLength);
754     uint16_t domainLength      = StringLength(aDomain, kMaxNameLength);
755 
756     if (nameLength > 0 && aName[nameLength - 1] == kLabelSeparatorChar)
757     {
758         nameEndsWithDot = true;
759         --nameLength;
760     }
761 
762     if (domainLength > 0 && aDomain[domainLength - 1] == kLabelSeparatorChar)
763     {
764         domainEndsWithDot = true;
765         --domainLength;
766     }
767 
768     VerifyOrExit(nameLength >= domainLength);
769 
770     aName += nameLength - domainLength;
771 
772     if (nameLength > domainLength)
773     {
774         VerifyOrExit(aName[-1] == kLabelSeparatorChar);
775     }
776 
777     // This method allows either `aName` or `aDomain` to include or
778     // exclude the last `.` character. If both include it or if both
779     // do not, we do a full comparison using `StringMatch()`.
780     // Otherwise (i.e., when one includes and the other one does not)
781     // we use `StringStartWith()` to allow the extra `.` character.
782 
783     if (nameEndsWithDot == domainEndsWithDot)
784     {
785         match = StringMatch(aName, aDomain, kStringCaseInsensitiveMatch);
786     }
787     else if (nameEndsWithDot)
788     {
789         // `aName` ends with dot, but `aDomain` does not.
790         match = StringStartsWith(aName, aDomain, kStringCaseInsensitiveMatch);
791     }
792     else
793     {
794         // `aDomain` ends with dot, but `aName` does not.
795         match = StringStartsWith(aDomain, aName, kStringCaseInsensitiveMatch);
796     }
797 
798 exit:
799     return match;
800 }
801 
IsSameDomain(const char * aDomain1,const char * aDomain2)802 bool Name::IsSameDomain(const char *aDomain1, const char *aDomain2)
803 {
804     return IsSubDomainOf(aDomain1, aDomain2) && IsSubDomainOf(aDomain2, aDomain1);
805 }
806 
ParseRecords(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords)807 Error ResourceRecord::ParseRecords(const Message &aMessage, uint16_t &aOffset, uint16_t aNumRecords)
808 {
809     Error error = kErrorNone;
810 
811     while (aNumRecords > 0)
812     {
813         ResourceRecord record;
814 
815         SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
816         SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
817         aOffset += static_cast<uint16_t>(record.GetSize());
818         aNumRecords--;
819     }
820 
821 exit:
822     return error;
823 }
824 
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t & aNumRecords,const Name & aName)825 Error ResourceRecord::FindRecord(const Message &aMessage, uint16_t &aOffset, uint16_t &aNumRecords, const Name &aName)
826 {
827     Error error;
828 
829     while (aNumRecords > 0)
830     {
831         bool           matches = true;
832         ResourceRecord record;
833 
834         error = Name::CompareName(aMessage, aOffset, aName);
835 
836         switch (error)
837         {
838         case kErrorNone:
839             break;
840         case kErrorNotFound:
841             matches = false;
842             break;
843         default:
844             ExitNow();
845         }
846 
847         SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
848         aNumRecords--;
849         VerifyOrExit(!matches);
850         aOffset += static_cast<uint16_t>(record.GetSize());
851     }
852 
853     error = kErrorNotFound;
854 
855 exit:
856     return error;
857 }
858 
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords,uint16_t aIndex,const Name & aName,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)859 Error ResourceRecord::FindRecord(const Message  &aMessage,
860                                  uint16_t       &aOffset,
861                                  uint16_t        aNumRecords,
862                                  uint16_t        aIndex,
863                                  const Name     &aName,
864                                  uint16_t        aType,
865                                  ResourceRecord &aRecord,
866                                  uint16_t        aMinRecordSize)
867 {
868     // This static method searches in `aMessage` starting from `aOffset`
869     // up to maximum of `aNumRecords`, for the `(aIndex+1)`th
870     // occurrence of a resource record of type `aType` with record name
871     // matching `aName`. It also verifies that the record size is larger
872     // than `aMinRecordSize`. If found, `aMinRecordSize` bytes from the
873     // record are read and copied into `aRecord`. In this case `aOffset`
874     // is updated to point to the last record byte read from the message
875     // (so that the caller can read any remaining fields in the record
876     // data).
877 
878     Error    error;
879     uint16_t offset = aOffset;
880     uint16_t recordOffset;
881 
882     while (aNumRecords > 0)
883     {
884         SuccessOrExit(error = FindRecord(aMessage, offset, aNumRecords, aName));
885 
886         // Save the offset to start of `ResourceRecord` fields.
887         recordOffset = offset;
888 
889         error = ReadRecord(aMessage, offset, aType, aRecord, aMinRecordSize);
890 
891         if (error == kErrorNotFound)
892         {
893             // `ReadRecord()` already updates the `offset` to skip
894             // over a non-matching record.
895             continue;
896         }
897 
898         SuccessOrExit(error);
899 
900         if (aIndex == 0)
901         {
902             aOffset = offset;
903             ExitNow();
904         }
905 
906         aIndex--;
907 
908         // Skip over the record.
909         offset = static_cast<uint16_t>(recordOffset + aRecord.GetSize());
910     }
911 
912     error = kErrorNotFound;
913 
914 exit:
915     return error;
916 }
917 
ReadRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)918 Error ResourceRecord::ReadRecord(const Message  &aMessage,
919                                  uint16_t       &aOffset,
920                                  uint16_t        aType,
921                                  ResourceRecord &aRecord,
922                                  uint16_t        aMinRecordSize)
923 {
924     // This static method tries to read a matching resource record of a
925     // given type and a minimum record size from a message. The `aType`
926     // value of `kTypeAny` matches any type.  If the record in the
927     // message does not match, it skips over the record. Please see
928     // `ReadRecord<RecordType>()` for more details.
929 
930     Error          error;
931     ResourceRecord record;
932 
933     SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
934 
935     if (((aType == kTypeAny) || (record.GetType() == aType)) && (record.GetSize() >= aMinRecordSize))
936     {
937         IgnoreError(aMessage.Read(aOffset, &aRecord, aMinRecordSize));
938         aOffset += aMinRecordSize;
939     }
940     else
941     {
942         // Skip over the entire record.
943         aOffset += static_cast<uint16_t>(record.GetSize());
944         error = kErrorNotFound;
945     }
946 
947 exit:
948     return error;
949 }
950 
ReadName(const Message & aMessage,uint16_t & aOffset,uint16_t aStartOffset,char * aNameBuffer,uint16_t aNameBufferSize,bool aSkipRecord) const951 Error ResourceRecord::ReadName(const Message &aMessage,
952                                uint16_t      &aOffset,
953                                uint16_t       aStartOffset,
954                                char          *aNameBuffer,
955                                uint16_t       aNameBufferSize,
956                                bool           aSkipRecord) const
957 {
958     // This protected method parses and reads a name field in a record
959     // from a message. It is intended only for sub-classes of
960     // `ResourceRecord`.
961     //
962     // On input `aOffset` gives the offset in `aMessage` to the start of
963     // name field. `aStartOffset` gives the offset to the start of the
964     // `ResourceRecord`. `aSkipRecord` indicates whether to skip over
965     // the entire resource record or just the read name. On exit, when
966     // successfully read, `aOffset` is updated to either point after the
967     // end of record or after the the name field.
968     //
969     // When read successfully, this method returns `kErrorNone`. On a
970     // parse error (invalid format) returns `kErrorParse`. If the
971     // name does not fit in the given name buffer it returns
972     // `kErrorNoBufs`
973 
974     Error error = kErrorNone;
975 
976     SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
977     VerifyOrExit(aOffset <= aStartOffset + GetSize(), error = kErrorParse);
978 
979     VerifyOrExit(aSkipRecord);
980     aOffset = aStartOffset;
981     error   = SkipRecord(aMessage, aOffset);
982 
983 exit:
984     return error;
985 }
986 
SkipRecord(const Message & aMessage,uint16_t & aOffset) const987 Error ResourceRecord::SkipRecord(const Message &aMessage, uint16_t &aOffset) const
988 {
989     // This protected method parses and skips over a resource record
990     // in a message.
991     //
992     // On input `aOffset` gives the offset in `aMessage` to the start of
993     // the `ResourceRecord`. On exit, when successfully parsed, `aOffset`
994     // is updated to point to byte after the entire record.
995 
996     Error error;
997 
998     SuccessOrExit(error = CheckRecord(aMessage, aOffset));
999     aOffset += static_cast<uint16_t>(GetSize());
1000 
1001 exit:
1002     return error;
1003 }
1004 
CheckRecord(const Message & aMessage,uint16_t aOffset) const1005 Error ResourceRecord::CheckRecord(const Message &aMessage, uint16_t aOffset) const
1006 {
1007     // This method checks that the entire record (including record data)
1008     // is present in `aMessage` at `aOffset` (pointing to the start of
1009     // the `ResourceRecord` in `aMessage`).
1010 
1011     return (aOffset + GetSize() <= aMessage.GetLength()) ? kErrorNone : kErrorParse;
1012 }
1013 
ReadFrom(const Message & aMessage,uint16_t aOffset)1014 Error ResourceRecord::ReadFrom(const Message &aMessage, uint16_t aOffset)
1015 {
1016     // This method reads the `ResourceRecord` from `aMessage` at
1017     // `aOffset`. It verifies that the entire record (including record
1018     // data) is present in the message.
1019 
1020     Error error;
1021 
1022     SuccessOrExit(error = aMessage.Read(aOffset, *this));
1023     error = CheckRecord(aMessage, aOffset);
1024 
1025 exit:
1026     return error;
1027 }
1028 
Init(const uint8_t * aTxtData,uint16_t aTxtDataLength)1029 void TxtEntry::Iterator::Init(const uint8_t *aTxtData, uint16_t aTxtDataLength)
1030 {
1031     SetTxtData(aTxtData);
1032     SetTxtDataLength(aTxtDataLength);
1033     SetTxtDataPosition(0);
1034 }
1035 
Error TxtEntry::Iterator::GetNextEntry(TxtEntry &aEntry)
{
    // Parses the next entry from the TXT data given to `Init()`, starting
    // at the iterator's current position, and advances the position past
    // it. Each entry is one length-prefixed string: a length byte followed
    // by that many bytes.
    //
    // On success (`kErrorNone`) `aEntry` is set as follows:
    //  - "key=value": `mKey` is the key (copied into this iterator's key
    //    buffer and null-terminated); `mValue`/`mValueLength` refer to the
    //    value bytes inside the TXT data (the value may be empty).
    //  - "key" (no `=`): boolean attribute; `mValue` is `nullptr` and
    //    `mValueLength` is zero.
    //  - Key too long for the key buffer: `mKey` is `nullptr` and
    //    `mValue`/`mValueLength` refer to the whole string (excluding its
    //    length byte).
    //
    // The key buffer belongs to the iterator and may be overwritten by the
    // next call, so a returned `mKey` should be consumed before that.
    //
    // Returns `kErrorNotFound` once no more entries remain, or
    // `kErrorParse` if no TXT data was set or a length byte claims more
    // bytes than remain (in which case the position is not advanced).

    Error       error = kErrorNone;
    uint8_t     length;
    uint8_t     index;
    const char *cur;
    char       *keyBuffer = GetKeyBuffer();

    // The key buffer must hold the longest supported key plus its null terminator.
    static_assert(sizeof(mChar) >= TxtEntry::kMaxKeyLength + 1, "KeyBuffer cannot fit the max key length");

    VerifyOrExit(GetTxtData() != nullptr, error = kErrorParse);

    aEntry.mKey = keyBuffer;

    while ((cur = GetTxtData() + GetTxtDataPosition()) < GetTxtDataEnd())
    {
        // First byte of each string is its length.
        length = static_cast<uint8_t>(*cur);

        cur++;
        VerifyOrExit(cur + length <= GetTxtDataEnd(), error = kErrorParse);
        IncreaseTxtDataPosition(sizeof(uint8_t) + length);

        // Silently skip over an empty string or if the string starts with
        // a `=` character (i.e., missing key) - RFC 6763 - section 6.4.

        if ((length == 0) || (cur[0] == kKeyValueSeparator))
        {
            continue;
        }

        for (index = 0; index < length; index++)
        {
            if (cur[index] == kKeyValueSeparator)
            {
                keyBuffer[index++]  = kNullChar; // Increment index to skip over `=`.
                aEntry.mValue       = reinterpret_cast<const uint8_t *>(&cur[index]);
                aEntry.mValueLength = length - index;
                ExitNow();
            }

            // Keep the last byte of the key buffer for the null terminator.
            if (index >= sizeof(mChar) - 1)
            {
                // The key is larger than supported key string length.
                // In this case, we return the full encoded string in
                // `mValue` and `mValueLength` and set `mKey` to
                // `nullptr`.

                aEntry.mKey         = nullptr;
                aEntry.mValue       = reinterpret_cast<const uint8_t *>(cur);
                aEntry.mValueLength = length;
                ExitNow();
            }

            keyBuffer[index] = cur[index];
        }

        // If we reach the end of the string without finding `=` then
        // it is a boolean key attribute (encoded as "key").

        keyBuffer[index]    = kNullChar;
        aEntry.mValue       = nullptr;
        aEntry.mValueLength = 0;
        ExitNow();
    }

    error = kErrorNotFound;

exit:
    return error;
}
1106 
AppendTo(Message & aMessage) const1107 Error TxtEntry::AppendTo(Message &aMessage) const
1108 {
1109     Appender appender(aMessage);
1110 
1111     return AppendTo(appender);
1112 }
1113 
AppendTo(Appender & aAppender) const1114 Error TxtEntry::AppendTo(Appender &aAppender) const
1115 {
1116     Error    error = kErrorNone;
1117     uint16_t keyLength;
1118     char     separator = kKeyValueSeparator;
1119 
1120     if (mKey == nullptr)
1121     {
1122         VerifyOrExit((mValue != nullptr) && (mValueLength != 0));
1123         error = aAppender.AppendBytes(mValue, mValueLength);
1124         ExitNow();
1125     }
1126 
1127     keyLength = StringLength(mKey, static_cast<uint16_t>(kMaxKeyValueEncodedSize) + 1);
1128 
1129     VerifyOrExit(kMinKeyLength <= keyLength, error = kErrorInvalidArgs);
1130 
1131     if (mValue == nullptr)
1132     {
1133         // Treat as a boolean attribute and encoded as "key" (with no `=`).
1134         SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength)));
1135         error = aAppender.AppendBytes(mKey, keyLength);
1136         ExitNow();
1137     }
1138 
1139     // Treat as key/value and encode as "key=value", value may be empty.
1140 
1141     VerifyOrExit(mValueLength + keyLength + sizeof(char) <= kMaxKeyValueEncodedSize, error = kErrorInvalidArgs);
1142 
1143     SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength + mValueLength + sizeof(char))));
1144     SuccessOrExit(error = aAppender.AppendBytes(mKey, keyLength));
1145     SuccessOrExit(error = aAppender.Append(separator));
1146     error = aAppender.AppendBytes(mValue, mValueLength);
1147 
1148 exit:
1149     return error;
1150 }
1151 
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,Message & aMessage)1152 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, Message &aMessage)
1153 {
1154     Appender appender(aMessage);
1155 
1156     return AppendEntries(aEntries, aNumEntries, appender);
1157 }
1158 
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,MutableData<kWithUint16Length> & aData)1159 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, MutableData<kWithUint16Length> &aData)
1160 {
1161     Error    error;
1162     Appender appender(aData.GetBytes(), aData.GetLength());
1163 
1164     SuccessOrExit(error = AppendEntries(aEntries, aNumEntries, appender));
1165     appender.GetAsData(aData);
1166 
1167 exit:
1168     return error;
1169 }
1170 
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,Appender & aAppender)1171 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, Appender &aAppender)
1172 {
1173     Error error = kErrorNone;
1174 
1175     for (uint16_t index = 0; index < aNumEntries; index++)
1176     {
1177         SuccessOrExit(error = aEntries[index].AppendTo(aAppender));
1178     }
1179 
1180     if (aAppender.GetAppendedLength() == 0)
1181     {
1182         error = aAppender.Append<uint8_t>(0);
1183     }
1184 
1185 exit:
1186     return error;
1187 }
1188 
IsValid(void) const1189 bool AaaaRecord::IsValid(void) const
1190 {
1191     return GetType() == Dns::ResourceRecord::kTypeAaaa && GetSize() == sizeof(*this);
1192 }
1193 
IsValid(void) const1194 bool KeyRecord::IsValid(void) const { return GetType() == Dns::ResourceRecord::kTypeKey; }
1195 
1196 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
Init(void)1197 void Ecdsa256KeyRecord::Init(void)
1198 {
1199     KeyRecord::Init();
1200     SetAlgorithm(kAlgorithmEcdsaP256Sha256);
1201 }
1202 
IsValid(void) const1203 bool Ecdsa256KeyRecord::IsValid(void) const
1204 {
1205     return KeyRecord::IsValid() && GetLength() == sizeof(*this) - sizeof(ResourceRecord) &&
1206            GetAlgorithm() == kAlgorithmEcdsaP256Sha256;
1207 }
1208 #endif
1209 
IsValid(void) const1210 bool SigRecord::IsValid(void) const
1211 {
1212     return GetType() == Dns::ResourceRecord::kTypeSig && GetLength() >= sizeof(*this) - sizeof(ResourceRecord);
1213 }
1214 
InitAsShortVariant(uint32_t aLeaseInterval)1215 void LeaseOption::InitAsShortVariant(uint32_t aLeaseInterval)
1216 {
1217     SetOptionCode(kUpdateLease);
1218     SetOptionLength(kShortLength);
1219     SetLeaseInterval(aLeaseInterval);
1220 }
1221 
InitAsLongVariant(uint32_t aLeaseInterval,uint32_t aKeyLeaseInterval)1222 void LeaseOption::InitAsLongVariant(uint32_t aLeaseInterval, uint32_t aKeyLeaseInterval)
1223 {
1224     SetOptionCode(kUpdateLease);
1225     SetOptionLength(kLongLength);
1226     SetLeaseInterval(aLeaseInterval);
1227     SetKeyLeaseInterval(aKeyLeaseInterval);
1228 }
1229 
IsValid(void) const1230 bool LeaseOption::IsValid(void) const
1231 {
1232     bool isValid = false;
1233 
1234     VerifyOrExit((GetOptionLength() == kShortLength) || (GetOptionLength() >= kLongLength));
1235     isValid = (GetLeaseInterval() <= GetKeyLeaseInterval());
1236 
1237 exit:
1238     return isValid;
1239 }
1240 
ReadFrom(const Message & aMessage,uint16_t aOffset,uint16_t aLength)1241 Error LeaseOption::ReadFrom(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
1242 {
1243     Error    error = kErrorNone;
1244     uint16_t endOffset;
1245 
1246     VerifyOrExit(static_cast<uint32_t>(aOffset) + aLength <= aMessage.GetLength(), error = kErrorParse);
1247 
1248     endOffset = aOffset + aLength;
1249 
1250     while (aOffset < endOffset)
1251     {
1252         uint16_t size;
1253 
1254         SuccessOrExit(error = aMessage.Read(aOffset, this, sizeof(Option)));
1255 
1256         VerifyOrExit(aOffset + GetSize() <= endOffset, error = kErrorParse);
1257 
1258         size = static_cast<uint16_t>(GetSize());
1259 
1260         if (GetOptionCode() == kUpdateLease)
1261         {
1262             VerifyOrExit(GetOptionLength() >= kShortLength, error = kErrorParse);
1263 
1264             IgnoreError(aMessage.Read(aOffset, this, Min(size, static_cast<uint16_t>(sizeof(LeaseOption)))));
1265             VerifyOrExit(IsValid(), error = kErrorParse);
1266 
1267             ExitNow();
1268         }
1269 
1270         aOffset += size;
1271     }
1272 
1273     error = kErrorNotFound;
1274 
1275 exit:
1276     return error;
1277 }
1278 
ReadPtrName(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const1279 Error PtrRecord::ReadPtrName(const Message &aMessage,
1280                              uint16_t      &aOffset,
1281                              char          *aLabelBuffer,
1282                              uint8_t        aLabelBufferSize,
1283                              char          *aNameBuffer,
1284                              uint16_t       aNameBufferSize) const
1285 {
1286     Error    error       = kErrorNone;
1287     uint16_t startOffset = aOffset - sizeof(PtrRecord); // start of `PtrRecord`.
1288 
1289     // Verify that the name is within the record data length.
1290     SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
1291     VerifyOrExit(aOffset <= startOffset + GetSize(), error = kErrorParse);
1292 
1293     aOffset = startOffset + sizeof(PtrRecord);
1294     SuccessOrExit(error = Name::ReadLabel(aMessage, aOffset, aLabelBuffer, aLabelBufferSize));
1295 
1296     if (aNameBuffer != nullptr)
1297     {
1298         SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
1299     }
1300 
1301     aOffset = startOffset;
1302     error   = SkipRecord(aMessage, aOffset);
1303 
1304 exit:
1305     return error;
1306 }
1307 
ReadTxtData(const Message & aMessage,uint16_t & aOffset,uint8_t * aTxtBuffer,uint16_t & aTxtBufferSize) const1308 Error TxtRecord::ReadTxtData(const Message &aMessage,
1309                              uint16_t      &aOffset,
1310                              uint8_t       *aTxtBuffer,
1311                              uint16_t      &aTxtBufferSize) const
1312 {
1313     Error error = kErrorNone;
1314 
1315     SuccessOrExit(error = aMessage.Read(aOffset, aTxtBuffer, Min(GetLength(), aTxtBufferSize)));
1316     aOffset += GetLength();
1317 
1318     VerifyOrExit(GetLength() <= aTxtBufferSize, error = kErrorNoBufs);
1319     aTxtBufferSize = GetLength();
1320     VerifyOrExit(VerifyTxtData(aTxtBuffer, aTxtBufferSize, /* aAllowEmpty */ true), error = kErrorParse);
1321 
1322 exit:
1323     return error;
1324 }
1325 
VerifyTxtData(const uint8_t * aTxtData,uint16_t aTxtLength,bool aAllowEmpty)1326 bool TxtRecord::VerifyTxtData(const uint8_t *aTxtData, uint16_t aTxtLength, bool aAllowEmpty)
1327 {
1328     bool    valid          = false;
1329     uint8_t curEntryLength = 0;
1330 
1331     // Per RFC 1035, TXT-DATA MUST have one or more <character-string>s.
1332     VerifyOrExit(aAllowEmpty || aTxtLength > 0);
1333 
1334     for (uint16_t i = 0; i < aTxtLength; ++i)
1335     {
1336         if (curEntryLength == 0)
1337         {
1338             curEntryLength = aTxtData[i];
1339         }
1340         else
1341         {
1342             --curEntryLength;
1343         }
1344     }
1345 
1346     valid = (curEntryLength == 0);
1347 
1348 exit:
1349     return valid;
1350 }
1351 
AddType(uint16_t aType)1352 void NsecRecord::TypeBitMap::AddType(uint16_t aType)
1353 {
1354     if ((aType >> 8) == mBlockNumber)
1355     {
1356         uint8_t  type  = static_cast<uint8_t>(aType & 0xff);
1357         uint8_t  index = (type / kBitsPerByte);
1358         uint16_t mask  = (0x80 >> (type % kBitsPerByte));
1359 
1360         mBitmaps[index] |= mask;
1361         mBitmapLength = Max<uint8_t>(mBitmapLength, index + 1);
1362     }
1363 }
1364 
ContainsType(uint16_t aType) const1365 bool NsecRecord::TypeBitMap::ContainsType(uint16_t aType) const
1366 {
1367     bool     contains = false;
1368     uint8_t  type     = static_cast<uint8_t>(aType & 0xff);
1369     uint8_t  index    = (type / kBitsPerByte);
1370     uint16_t mask     = (0x80 >> (type % kBitsPerByte));
1371 
1372     VerifyOrExit((aType >> 8) == mBlockNumber);
1373 
1374     VerifyOrExit(index < mBitmapLength);
1375 
1376     contains = (mBitmaps[index] & mask);
1377 
1378 exit:
1379     return contains;
1380 }
1381 
1382 } // namespace Dns
1383 } // namespace ot
1384