1 /*
2  *  Copyright (c) 2020, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements generating and processing of DNS headers and helper functions/methods.
32  */
33 
34 #include "dns_types.hpp"
35 
36 #include "common/code_utils.hpp"
37 #include "common/debug.hpp"
38 #include "common/instance.hpp"
39 #include "common/num_utils.hpp"
40 #include "common/random.hpp"
41 #include "common/string.hpp"
42 
43 namespace ot {
44 namespace Dns {
45 
46 using ot::Encoding::BigEndian::HostSwap16;
47 
SetRandomMessageId(void)48 Error Header::SetRandomMessageId(void)
49 {
50     return Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&mMessageId), sizeof(mMessageId));
51 }
52 
ResponseCodeToError(Response aResponse)53 Error Header::ResponseCodeToError(Response aResponse)
54 {
55     Error error = kErrorFailed;
56 
57     switch (aResponse)
58     {
59     case kResponseSuccess:
60         error = kErrorNone;
61         break;
62 
63     case kResponseFormatError:   // Server unable to interpret request due to format error.
64     case kResponseBadName:       // Bad name.
65     case kResponseBadTruncation: // Bad truncation.
66     case kResponseNotZone:       // A name is not in the zone.
67         error = kErrorParse;
68         break;
69 
70     case kResponseServerFailure: // Server encountered an internal failure.
71         error = kErrorFailed;
72         break;
73 
74     case kResponseNameError:       // Name that ought to exist, does not exists.
75     case kResponseRecordNotExists: // Some RRset that ought to exist, does not exist.
76         error = kErrorNotFound;
77         break;
78 
79     case kResponseNotImplemented: // Server does not support the query type (OpCode).
80     case kDsoTypeNotImplemented:  // DSO TLV type is not implemented.
81         error = kErrorNotImplemented;
82         break;
83 
84     case kResponseBadAlg: // Bad algorithm.
85         error = kErrorNotCapable;
86         break;
87 
88     case kResponseNameExists:   // Some name that ought not to exist, does exist.
89     case kResponseRecordExists: // Some RRset that ought not to exist, does exist.
90         error = kErrorDuplicated;
91         break;
92 
93     case kResponseRefused: // Server refused to perform operation for policy or security reasons.
94     case kResponseNotAuth: // Service is not authoritative for zone.
95         error = kErrorSecurity;
96         break;
97 
98     default:
99         break;
100     }
101 
102     return error;
103 }
104 
AppendTo(Message & aMessage) const105 Error Name::AppendTo(Message &aMessage) const
106 {
107     Error error;
108 
109     if (IsEmpty())
110     {
111         error = AppendTerminator(aMessage);
112     }
113     else if (IsFromCString())
114     {
115         error = AppendName(GetAsCString(), aMessage);
116     }
117     else
118     {
119         // Name is from a message. Read labels one by one from
120         // `mMessage` and and append each to the `aMessage`.
121 
122         LabelIterator iterator(*mMessage, mOffset);
123 
124         while (true)
125         {
126             error = iterator.GetNextLabel();
127 
128             switch (error)
129             {
130             case kErrorNone:
131                 SuccessOrExit(error = iterator.AppendLabel(aMessage));
132                 break;
133 
134             case kErrorNotFound:
135                 // We reached the end of name successfully.
136                 error = AppendTerminator(aMessage);
137 
138                 OT_FALL_THROUGH;
139 
140             default:
141                 ExitNow();
142             }
143         }
144     }
145 
146 exit:
147     return error;
148 }
149 
AppendLabel(const char * aLabel,Message & aMessage)150 Error Name::AppendLabel(const char *aLabel, Message &aMessage)
151 {
152     return AppendLabel(aLabel, static_cast<uint8_t>(StringLength(aLabel, kMaxLabelSize)), aMessage);
153 }
154 
AppendLabel(const char * aLabel,uint8_t aLength,Message & aMessage)155 Error Name::AppendLabel(const char *aLabel, uint8_t aLength, Message &aMessage)
156 {
157     Error error = kErrorNone;
158 
159     VerifyOrExit((0 < aLength) && (aLength <= kMaxLabelLength), error = kErrorInvalidArgs);
160 
161     SuccessOrExit(error = aMessage.Append(aLength));
162     error = aMessage.AppendBytes(aLabel, aLength);
163 
164 exit:
165     return error;
166 }
167 
AppendMultipleLabels(const char * aLabels,Message & aMessage)168 Error Name::AppendMultipleLabels(const char *aLabels, Message &aMessage)
169 {
170     Error    error           = kErrorNone;
171     uint16_t index           = 0;
172     uint16_t labelStartIndex = 0;
173     char     ch;
174 
175     VerifyOrExit(aLabels != nullptr);
176 
177     do
178     {
179         ch = aLabels[index];
180 
181         if ((ch == kNullChar) || (ch == kLabelSeparatorChar))
182         {
183             uint8_t labelLength = static_cast<uint8_t>(index - labelStartIndex);
184 
185             if (labelLength == 0)
186             {
187                 // Empty label (e.g., consecutive dots) is invalid, but we
188                 // allow for two cases: (1) where `aLabels` ends with a dot
189                 // (`labelLength` is zero but we are at end of `aLabels` string
190                 // and `ch` is null char. (2) if `aLabels` is just "." (we
191                 // see a dot at index 0, and index 1 is null char).
192 
193                 error =
194                     ((ch == kNullChar) || ((index == 0) && (aLabels[1] == kNullChar))) ? kErrorNone : kErrorInvalidArgs;
195                 ExitNow();
196             }
197 
198             VerifyOrExit(index + 1 < kMaxEncodedLength, error = kErrorInvalidArgs);
199             SuccessOrExit(error = AppendLabel(&aLabels[labelStartIndex], labelLength, aMessage));
200 
201             labelStartIndex = index + 1;
202         }
203 
204         index++;
205 
206     } while (ch != kNullChar);
207 
208 exit:
209     return error;
210 }
211 
AppendTerminator(Message & aMessage)212 Error Name::AppendTerminator(Message &aMessage)
213 {
214     uint8_t terminator = 0;
215 
216     return aMessage.Append(terminator);
217 }
218 
AppendPointerLabel(uint16_t aOffset,Message & aMessage)219 Error Name::AppendPointerLabel(uint16_t aOffset, Message &aMessage)
220 {
221     Error    error;
222     uint16_t value;
223 
224 #if OPENTHREAD_CONFIG_REFERENCE_DEVICE_ENABLE
225     if (!Instance::IsDnsNameCompressionEnabled())
226     {
227         // If "DNS name compression" mode is disabled, instead of
228         // appending the pointer label, read the name from the message
229         // and append it uncompressed. Note that the `aOffset` parameter
230         // in this method is given relative to the start of DNS header
231         // in `aMessage` (which `aMessage.GetOffset()` specifies).
232 
233         error = Name(aMessage, aOffset + aMessage.GetOffset()).AppendTo(aMessage);
234         ExitNow();
235     }
236 #endif
237 
238     // A pointer label takes the form of a two byte sequence as a
239     // `uint16_t` value. The first two bits are ones. This allows a
240     // pointer to be distinguished from a text label, since the text
241     // label must begin with two zero bits (note that labels are
242     // restricted to 63 octets or less). The next 14-bits specify
243     // an offset value relative to start of DNS header.
244 
245     OT_ASSERT(aOffset < kPointerLabelTypeUint16);
246 
247     value = HostSwap16(aOffset | kPointerLabelTypeUint16);
248 
249     ExitNow(error = aMessage.Append(value));
250 
251 exit:
252     return error;
253 }
254 
AppendName(const char * aName,Message & aMessage)255 Error Name::AppendName(const char *aName, Message &aMessage)
256 {
257     Error error;
258 
259     SuccessOrExit(error = AppendMultipleLabels(aName, aMessage));
260     error = AppendTerminator(aMessage);
261 
262 exit:
263     return error;
264 }
265 
ParseName(const Message & aMessage,uint16_t & aOffset)266 Error Name::ParseName(const Message &aMessage, uint16_t &aOffset)
267 {
268     Error         error;
269     LabelIterator iterator(aMessage, aOffset);
270 
271     while (true)
272     {
273         error = iterator.GetNextLabel();
274 
275         switch (error)
276         {
277         case kErrorNone:
278             break;
279 
280         case kErrorNotFound:
281             // We reached the end of name successfully.
282             aOffset = iterator.mNameEndOffset;
283             error   = kErrorNone;
284 
285             OT_FALL_THROUGH;
286 
287         default:
288             ExitNow();
289         }
290     }
291 
292 exit:
293     return error;
294 }
295 
ReadLabel(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t & aLabelLength)296 Error Name::ReadLabel(const Message &aMessage, uint16_t &aOffset, char *aLabelBuffer, uint8_t &aLabelLength)
297 {
298     Error         error;
299     LabelIterator iterator(aMessage, aOffset);
300 
301     SuccessOrExit(error = iterator.GetNextLabel());
302     SuccessOrExit(error = iterator.ReadLabel(aLabelBuffer, aLabelLength, /* aAllowDotCharInLabel */ true));
303     aOffset = iterator.mNextLabelOffset;
304 
305 exit:
306     return error;
307 }
308 
ReadName(const Message & aMessage,uint16_t & aOffset,char * aNameBuffer,uint16_t aNameBufferSize)309 Error Name::ReadName(const Message &aMessage, uint16_t &aOffset, char *aNameBuffer, uint16_t aNameBufferSize)
310 {
311     Error         error;
312     LabelIterator iterator(aMessage, aOffset);
313     bool          firstLabel = true;
314     uint8_t       labelLength;
315 
316     while (true)
317     {
318         error = iterator.GetNextLabel();
319 
320         switch (error)
321         {
322         case kErrorNone:
323 
324             if (!firstLabel)
325             {
326                 *aNameBuffer++ = kLabelSeparatorChar;
327                 aNameBufferSize--;
328 
329                 // No need to check if we have reached end of the name buffer
330                 // here since `iterator.ReadLabel()` would verify it.
331             }
332 
333             labelLength = static_cast<uint8_t>(Min(static_cast<uint16_t>(kMaxLabelSize), aNameBufferSize));
334             SuccessOrExit(error = iterator.ReadLabel(aNameBuffer, labelLength, /* aAllowDotCharInLabel */ firstLabel));
335             aNameBuffer += labelLength;
336             aNameBufferSize -= labelLength;
337             firstLabel = false;
338             break;
339 
340         case kErrorNotFound:
341             // We reach the end of name successfully. Always add a terminating dot
342             // at the end.
343             *aNameBuffer++ = kLabelSeparatorChar;
344             aNameBufferSize--;
345             VerifyOrExit(aNameBufferSize >= sizeof(uint8_t), error = kErrorNoBufs);
346             *aNameBuffer = kNullChar;
347             aOffset      = iterator.mNameEndOffset;
348             error        = kErrorNone;
349 
350             OT_FALL_THROUGH;
351 
352         default:
353             ExitNow();
354         }
355     }
356 
357 exit:
358     return error;
359 }
360 
CompareLabel(const Message & aMessage,uint16_t & aOffset,const char * aLabel)361 Error Name::CompareLabel(const Message &aMessage, uint16_t &aOffset, const char *aLabel)
362 {
363     Error         error;
364     LabelIterator iterator(aMessage, aOffset);
365 
366     SuccessOrExit(error = iterator.GetNextLabel());
367     VerifyOrExit(iterator.CompareLabel(aLabel, kIsSingleLabel), error = kErrorNotFound);
368     aOffset = iterator.mNextLabelOffset;
369 
370 exit:
371     return error;
372 }
373 
CompareName(const Message & aMessage,uint16_t & aOffset,const char * aName)374 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const char *aName)
375 {
376     Error         error;
377     LabelIterator iterator(aMessage, aOffset);
378     bool          matches = true;
379 
380     if (*aName == kLabelSeparatorChar)
381     {
382         aName++;
383         VerifyOrExit(*aName == kNullChar, error = kErrorInvalidArgs);
384     }
385 
386     while (true)
387     {
388         error = iterator.GetNextLabel();
389 
390         switch (error)
391         {
392         case kErrorNone:
393             if (matches && !iterator.CompareLabel(aName, !kIsSingleLabel))
394             {
395                 matches = false;
396             }
397 
398             break;
399 
400         case kErrorNotFound:
401             // We reached the end of the name in `aMessage`. We check if
402             // all the previous labels matched so far, and we are also
403             // at the end of `aName` string (see null char), then we
404             // return `kErrorNone` indicating a successful comparison
405             // (full match). Otherwise we return `kErrorNotFound` to
406             // indicate failed comparison.
407 
408             if (matches && (*aName == kNullChar))
409             {
410                 error = kErrorNone;
411             }
412 
413             aOffset = iterator.mNameEndOffset;
414 
415             OT_FALL_THROUGH;
416 
417         default:
418             ExitNow();
419         }
420     }
421 
422 exit:
423     return error;
424 }
425 
CompareName(const Message & aMessage,uint16_t & aOffset,const Message & aMessage2,uint16_t aOffset2)426 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Message &aMessage2, uint16_t aOffset2)
427 {
428     Error         error;
429     LabelIterator iterator(aMessage, aOffset);
430     LabelIterator iterator2(aMessage2, aOffset2);
431     bool          matches = true;
432 
433     while (true)
434     {
435         error = iterator.GetNextLabel();
436 
437         switch (error)
438         {
439         case kErrorNone:
440             // If all the previous labels matched so far, then verify
441             // that we can get the next label on `iterator2` and that it
442             // matches the label from `iterator`.
443             if (matches && (iterator2.GetNextLabel() != kErrorNone || !iterator.CompareLabel(iterator2)))
444             {
445                 matches = false;
446             }
447 
448             break;
449 
450         case kErrorNotFound:
451             // We reached the end of the name in `aMessage`. We check
452             // that `iterator2` is also at its end, and if all previous
453             // labels matched we return `kErrorNone`.
454 
455             if (matches && (iterator2.GetNextLabel() == kErrorNotFound))
456             {
457                 error = kErrorNone;
458             }
459 
460             aOffset = iterator.mNameEndOffset;
461 
462             OT_FALL_THROUGH;
463 
464         default:
465             ExitNow();
466         }
467     }
468 
469 exit:
470     return error;
471 }
472 
CompareName(const Message & aMessage,uint16_t & aOffset,const Name & aName)473 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Name &aName)
474 {
475     return aName.IsFromCString()
476                ? CompareName(aMessage, aOffset, aName.mString)
477                : (aName.IsFromMessage() ? CompareName(aMessage, aOffset, *aName.mMessage, aName.mOffset)
478                                         : ParseName(aMessage, aOffset));
479 }
480 
Error Name::LabelIterator::GetNextLabel(void)
{
    // Advances the iterator to the next text label of the name, following
    // any compression pointer labels on the way.
    //
    // Returns:
    //  - `kErrorNone` when a text label is found. `mLabelStartOffset` and
    //    `mLabelLength` then describe it, and `mNextLabelOffset` points just
    //    past it.
    //  - `kErrorNotFound` when the zero-length label (end of name) is reached.
    //  - `kErrorParse` on an unrecognized label type, or on a pointer that
    //    does not point strictly backwards.
    //  - Any error from `mMessage.Read()` (presumably when the name runs
    //    past the end of the message - confirm against `Message::Read()`).
    //
    // `mNameEndOffset` is written only while `IsEndOffsetSet()` is false,
    // i.e. at the first pointer label or at the terminating zero label,
    // whichever comes first. It therefore records where the name ends at its
    // original location, not where a followed pointer leads.

    Error error;

    while (true)
    {
        uint8_t labelLength;
        uint8_t labelType;

        SuccessOrExit(error = mMessage.Read(mNextLabelOffset, labelLength));

        // `kLabelTypeMask` isolates the label-type bits of the length byte.
        labelType = labelLength & kLabelTypeMask;

        if (labelType == kTextLabelType)
        {
            if (labelLength == 0)
            {
                // Zero label length indicates end of a name.

                if (!IsEndOffsetSet())
                {
                    mNameEndOffset = mNextLabelOffset + sizeof(uint8_t);
                }

                ExitNow(error = kErrorNotFound);
            }

            // Text label: the label chars directly follow the length byte.
            mLabelStartOffset = mNextLabelOffset + sizeof(uint8_t);
            mLabelLength      = labelLength;
            mNextLabelOffset  = mLabelStartOffset + labelLength;
            ExitNow();
        }
        else if (labelType == kPointerLabelType)
        {
            // A pointer label takes the form of a two byte sequence as a
            // `uint16_t` value. The first two bits are ones. The next 14 bits
            // specify an offset value from the start of the DNS header.

            uint16_t pointerValue;
            uint16_t nextLabelOffset;

            SuccessOrExit(error = mMessage.Read(mNextLabelOffset, pointerValue));

            // At its original location, the name ends right after the first
            // pointer encountered.
            if (!IsEndOffsetSet())
            {
                mNameEndOffset = mNextLabelOffset + sizeof(uint16_t);
            }

            // `mMessage.GetOffset()` must point to the start of the
            // DNS header.
            nextLabelOffset = mMessage.GetOffset() + (HostSwap16(pointerValue) & kPointerLabelOffsetMask);

            // Accept only backward pointers. The offset strictly decreases on
            // every jump, so a crafted message cannot make this loop spin
            // forever through a pointer cycle.
            VerifyOrExit(nextLabelOffset < mNextLabelOffset, error = kErrorParse);
            mNextLabelOffset = nextLabelOffset;

            // Go back through the `while(true)` loop to get the next label.
        }
        else
        {
            // The two remaining label-type bit patterns are not supported.
            ExitNow(error = kErrorParse);
        }
    }

exit:
    return error;
}
546 
ReadLabel(char * aLabelBuffer,uint8_t & aLabelLength,bool aAllowDotCharInLabel) const547 Error Name::LabelIterator::ReadLabel(char *aLabelBuffer, uint8_t &aLabelLength, bool aAllowDotCharInLabel) const
548 {
549     Error error;
550 
551     VerifyOrExit(mLabelLength < aLabelLength, error = kErrorNoBufs);
552 
553     SuccessOrExit(error = mMessage.Read(mLabelStartOffset, aLabelBuffer, mLabelLength));
554     aLabelBuffer[mLabelLength] = kNullChar;
555     aLabelLength               = mLabelLength;
556 
557     if (!aAllowDotCharInLabel)
558     {
559         VerifyOrExit(StringFind(aLabelBuffer, kLabelSeparatorChar) == nullptr, error = kErrorParse);
560     }
561 
562 exit:
563     return error;
564 }
565 
CaseInsensitiveMatch(uint8_t aFirst,uint8_t aSecond)566 bool Name::LabelIterator::CaseInsensitiveMatch(uint8_t aFirst, uint8_t aSecond)
567 {
568     return ToLowercase(static_cast<char>(aFirst)) == ToLowercase(static_cast<char>(aSecond));
569 }
570 
CompareLabel(const char * & aName,bool aIsSingleLabel) const571 bool Name::LabelIterator::CompareLabel(const char *&aName, bool aIsSingleLabel) const
572 {
573     // This method compares the current label in the iterator with the
574     // `aName` string. `aIsSingleLabel` indicates whether `aName` is a
575     // single label, or a sequence of labels separated by dot '.' char.
576     // If the label matches `aName`, then `aName` pointer is moved
577     // forward to the start of the next label (skipping over the `.`
578     // char). This method returns `true` when the labels match, `false`
579     // otherwise.
580 
581     bool matches = false;
582 
583     VerifyOrExit(StringLength(aName, mLabelLength) == mLabelLength);
584     matches = mMessage.CompareBytes(mLabelStartOffset, aName, mLabelLength, CaseInsensitiveMatch);
585 
586     VerifyOrExit(matches);
587 
588     aName += mLabelLength;
589 
590     // If `aName` is a single label, we should be also at the end of the
591     // `aName` string. Otherwise, we should see either null or dot '.'
592     // character (in case `aName` contains multiple labels).
593 
594     matches = (*aName == kNullChar);
595 
596     if (!aIsSingleLabel && (*aName == kLabelSeparatorChar))
597     {
598         matches = true;
599         aName++;
600     }
601 
602 exit:
603     return matches;
604 }
605 
CompareLabel(const LabelIterator & aOtherIterator) const606 bool Name::LabelIterator::CompareLabel(const LabelIterator &aOtherIterator) const
607 {
608     // This method compares the current label in the iterator with the
609     // label from another iterator.
610 
611     return (mLabelLength == aOtherIterator.mLabelLength) &&
612            mMessage.CompareBytes(mLabelStartOffset, aOtherIterator.mMessage, aOtherIterator.mLabelStartOffset,
613                                  mLabelLength, CaseInsensitiveMatch);
614 }
615 
AppendLabel(Message & aMessage) const616 Error Name::LabelIterator::AppendLabel(Message &aMessage) const
617 {
618     // This method reads and appends the current label in the iterator
619     // to `aMessage`.
620 
621     Error error;
622 
623     VerifyOrExit((0 < mLabelLength) && (mLabelLength <= kMaxLabelLength), error = kErrorInvalidArgs);
624     SuccessOrExit(error = aMessage.Append(mLabelLength));
625     error = aMessage.AppendBytesFromMessage(mMessage, mLabelStartOffset, mLabelLength);
626 
627 exit:
628     return error;
629 }
630 
ExtractLabels(const char * aName,const char * aSuffixName,char * aLabels,uint16_t aLabelsSize)631 Error Name::ExtractLabels(const char *aName, const char *aSuffixName, char *aLabels, uint16_t aLabelsSize)
632 {
633     Error       error        = kErrorParse;
634     uint16_t    nameLength   = StringLength(aName, kMaxNameSize);
635     uint16_t    suffixLength = StringLength(aSuffixName, kMaxNameSize);
636     const char *suffixStart;
637 
638     VerifyOrExit(nameLength < kMaxNameSize);
639     VerifyOrExit(suffixLength < kMaxNameSize);
640 
641     VerifyOrExit(nameLength > suffixLength);
642 
643     suffixStart = aName + nameLength - suffixLength;
644     VerifyOrExit(StringMatch(suffixStart, aSuffixName, kStringCaseInsensitiveMatch));
645     suffixStart--;
646     VerifyOrExit(*suffixStart == kLabelSeparatorChar);
647 
648     // Determine the labels length to copy
649     nameLength -= (suffixLength + 1);
650     VerifyOrExit(nameLength < aLabelsSize, error = kErrorNoBufs);
651 
652     memcpy(aLabels, aName, nameLength);
653     aLabels[nameLength] = kNullChar;
654     error               = kErrorNone;
655 
656 exit:
657     return error;
658 }
659 
IsSubDomainOf(const char * aName,const char * aDomain)660 bool Name::IsSubDomainOf(const char *aName, const char *aDomain)
661 {
662     bool     match             = false;
663     bool     nameEndsWithDot   = false;
664     bool     domainEndsWithDot = false;
665     uint16_t nameLength        = StringLength(aName, kMaxNameLength);
666     uint16_t domainLength      = StringLength(aDomain, kMaxNameLength);
667 
668     if (nameLength > 0 && aName[nameLength - 1] == kLabelSeparatorChar)
669     {
670         nameEndsWithDot = true;
671         --nameLength;
672     }
673 
674     if (domainLength > 0 && aDomain[domainLength - 1] == kLabelSeparatorChar)
675     {
676         domainEndsWithDot = true;
677         --domainLength;
678     }
679 
680     VerifyOrExit(nameLength >= domainLength);
681 
682     aName += nameLength - domainLength;
683 
684     if (nameLength > domainLength)
685     {
686         VerifyOrExit(aName[-1] == kLabelSeparatorChar);
687     }
688 
689     // This method allows either `aName` or `aDomain` to include or
690     // exclude the last `.` character. If both include it or if both
691     // do not, we do a full comparison using `StringMatch()`.
692     // Otherwise (i.e., when one includes and the other one does not)
693     // we use `StringStartWith()` to allow the extra `.` character.
694 
695     if (nameEndsWithDot == domainEndsWithDot)
696     {
697         match = StringMatch(aName, aDomain, kStringCaseInsensitiveMatch);
698     }
699     else if (nameEndsWithDot)
700     {
701         // `aName` ends with dot, but `aDomain` does not.
702         match = StringStartsWith(aName, aDomain, kStringCaseInsensitiveMatch);
703     }
704     else
705     {
706         // `aDomain` ends with dot, but `aName` does not.
707         match = StringStartsWith(aDomain, aName, kStringCaseInsensitiveMatch);
708     }
709 
710 exit:
711     return match;
712 }
713 
IsSameDomain(const char * aDomain1,const char * aDomain2)714 bool Name::IsSameDomain(const char *aDomain1, const char *aDomain2)
715 {
716     return IsSubDomainOf(aDomain1, aDomain2) && IsSubDomainOf(aDomain2, aDomain1);
717 }
718 
ParseRecords(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords)719 Error ResourceRecord::ParseRecords(const Message &aMessage, uint16_t &aOffset, uint16_t aNumRecords)
720 {
721     Error error = kErrorNone;
722 
723     while (aNumRecords > 0)
724     {
725         ResourceRecord record;
726 
727         SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
728         SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
729         aOffset += static_cast<uint16_t>(record.GetSize());
730         aNumRecords--;
731     }
732 
733 exit:
734     return error;
735 }
736 
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t & aNumRecords,const Name & aName)737 Error ResourceRecord::FindRecord(const Message &aMessage, uint16_t &aOffset, uint16_t &aNumRecords, const Name &aName)
738 {
739     Error error;
740 
741     while (aNumRecords > 0)
742     {
743         bool           matches = true;
744         ResourceRecord record;
745 
746         error = Name::CompareName(aMessage, aOffset, aName);
747 
748         switch (error)
749         {
750         case kErrorNone:
751             break;
752         case kErrorNotFound:
753             matches = false;
754             break;
755         default:
756             ExitNow();
757         }
758 
759         SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
760         aNumRecords--;
761         VerifyOrExit(!matches);
762         aOffset += static_cast<uint16_t>(record.GetSize());
763     }
764 
765     error = kErrorNotFound;
766 
767 exit:
768     return error;
769 }
770 
Error ResourceRecord::FindRecord(const Message  &aMessage,
                                 uint16_t       &aOffset,
                                 uint16_t        aNumRecords,
                                 uint16_t        aIndex,
                                 const Name     &aName,
                                 uint16_t        aType,
                                 ResourceRecord &aRecord,
                                 uint16_t        aMinRecordSize)
{
    // This static method searches `aMessage`, starting from `aOffset` and
    // covering at most `aNumRecords` records, for the `(aIndex+1)`th
    // occurrence of a resource record that meets three conditions:
    //  - its type is `aType` (`kTypeAny` matches any type);
    //  - its record name matches `aName`;
    //  - its size is at least `aMinRecordSize` bytes.
    // Records that do not qualify are skipped and do not count towards
    // `aIndex`.
    //
    // If found, the first `aMinRecordSize` bytes of the record, starting at
    // its `ResourceRecord` fields (after the name), are copied into `aRecord`.
    // `aOffset` is then updated to point just past those bytes, so the caller
    // can read any remaining fields in the record data, and `kErrorNone` is
    // returned.
    //
    // If no such record exists, `kErrorNotFound` is returned. Any other
    // error from parsing the message is returned as is. In both cases
    // `aOffset` is left unchanged.

    Error    error;
    uint16_t offset = aOffset;
    uint16_t recordOffset;

    // `aNumRecords` is passed by reference to the name-matching
    // `FindRecord()`, which deducts every record it examines. That bounds
    // this loop.
    while (aNumRecords > 0)
    {
        SuccessOrExit(error = FindRecord(aMessage, offset, aNumRecords, aName));

        // Save the offset to start of `ResourceRecord` fields.
        recordOffset = offset;

        error = ReadRecord(aMessage, offset, aType, aRecord, aMinRecordSize);

        if (error == kErrorNotFound)
        {
            // Type or size mismatch: `ReadRecord()` already updated `offset`
            // to skip over the non-matching record.
            continue;
        }

        SuccessOrExit(error);

        if (aIndex == 0)
        {
            // `offset` now points just past the `aMinRecordSize` bytes read.
            aOffset = offset;
            ExitNow();
        }

        aIndex--;

        // Skip over the whole record. `aRecord` holds its header, which was
        // read by `ReadRecord()`, so `GetSize()` reflects this record.
        offset = static_cast<uint16_t>(recordOffset + aRecord.GetSize());
    }

    error = kErrorNotFound;

exit:
    return error;
}
829 
ReadRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)830 Error ResourceRecord::ReadRecord(const Message  &aMessage,
831                                  uint16_t       &aOffset,
832                                  uint16_t        aType,
833                                  ResourceRecord &aRecord,
834                                  uint16_t        aMinRecordSize)
835 {
836     // This static method tries to read a matching resource record of a
837     // given type and a minimum record size from a message. The `aType`
838     // value of `kTypeAny` matches any type.  If the record in the
839     // message does not match, it skips over the record. Please see
840     // `ReadRecord<RecordType>()` for more details.
841 
842     Error          error;
843     ResourceRecord record;
844 
845     SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
846 
847     if (((aType == kTypeAny) || (record.GetType() == aType)) && (record.GetSize() >= aMinRecordSize))
848     {
849         IgnoreError(aMessage.Read(aOffset, &aRecord, aMinRecordSize));
850         aOffset += aMinRecordSize;
851     }
852     else
853     {
854         // Skip over the entire record.
855         aOffset += static_cast<uint16_t>(record.GetSize());
856         error = kErrorNotFound;
857     }
858 
859 exit:
860     return error;
861 }
862 
ReadName(const Message & aMessage,uint16_t & aOffset,uint16_t aStartOffset,char * aNameBuffer,uint16_t aNameBufferSize,bool aSkipRecord) const863 Error ResourceRecord::ReadName(const Message &aMessage,
864                                uint16_t      &aOffset,
865                                uint16_t       aStartOffset,
866                                char          *aNameBuffer,
867                                uint16_t       aNameBufferSize,
868                                bool           aSkipRecord) const
869 {
870     // This protected method parses and reads a name field in a record
871     // from a message. It is intended only for sub-classes of
872     // `ResourceRecord`.
873     //
874     // On input `aOffset` gives the offset in `aMessage` to the start of
875     // name field. `aStartOffset` gives the offset to the start of the
876     // `ResourceRecord`. `aSkipRecord` indicates whether to skip over
877     // the entire resource record or just the read name. On exit, when
878     // successfully read, `aOffset` is updated to either point after the
879     // end of record or after the the name field.
880     //
881     // When read successfully, this method returns `kErrorNone`. On a
882     // parse error (invalid format) returns `kErrorParse`. If the
883     // name does not fit in the given name buffer it returns
884     // `kErrorNoBufs`
885 
886     Error error = kErrorNone;
887 
888     SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
889     VerifyOrExit(aOffset <= aStartOffset + GetSize(), error = kErrorParse);
890 
891     VerifyOrExit(aSkipRecord);
892     aOffset = aStartOffset;
893     error   = SkipRecord(aMessage, aOffset);
894 
895 exit:
896     return error;
897 }
898 
SkipRecord(const Message & aMessage,uint16_t & aOffset) const899 Error ResourceRecord::SkipRecord(const Message &aMessage, uint16_t &aOffset) const
900 {
901     // This protected method parses and skips over a resource record
902     // in a message.
903     //
904     // On input `aOffset` gives the offset in `aMessage` to the start of
905     // the `ResourceRecord`. On exit, when successfully parsed, `aOffset`
906     // is updated to point to byte after the entire record.
907 
908     Error error;
909 
910     SuccessOrExit(error = CheckRecord(aMessage, aOffset));
911     aOffset += static_cast<uint16_t>(GetSize());
912 
913 exit:
914     return error;
915 }
916 
CheckRecord(const Message & aMessage,uint16_t aOffset) const917 Error ResourceRecord::CheckRecord(const Message &aMessage, uint16_t aOffset) const
918 {
919     // This method checks that the entire record (including record data)
920     // is present in `aMessage` at `aOffset` (pointing to the start of
921     // the `ResourceRecord` in `aMessage`).
922 
923     return (aOffset + GetSize() <= aMessage.GetLength()) ? kErrorNone : kErrorParse;
924 }
925 
ReadFrom(const Message & aMessage,uint16_t aOffset)926 Error ResourceRecord::ReadFrom(const Message &aMessage, uint16_t aOffset)
927 {
928     // This method reads the `ResourceRecord` from `aMessage` at
929     // `aOffset`. It verifies that the entire record (including record
930     // data) is present in the message.
931 
932     Error error;
933 
934     SuccessOrExit(error = aMessage.Read(aOffset, *this));
935     error = CheckRecord(aMessage, aOffset);
936 
937 exit:
938     return error;
939 }
940 
Init(const uint8_t * aTxtData,uint16_t aTxtDataLength)941 void TxtEntry::Iterator::Init(const uint8_t *aTxtData, uint16_t aTxtDataLength)
942 {
943     SetTxtData(aTxtData);
944     SetTxtDataLength(aTxtDataLength);
945     SetTxtDataPosition(0);
946 }
947 
GetNextEntry(TxtEntry & aEntry)948 Error TxtEntry::Iterator::GetNextEntry(TxtEntry &aEntry)
949 {
950     Error       error = kErrorNone;
951     uint8_t     length;
952     uint8_t     index;
953     const char *cur;
954     char       *keyBuffer = GetKeyBuffer();
955 
956     static_assert(sizeof(mChar) == TxtEntry::kMaxKeyLength + 1, "KeyBuffer cannot fit the max key length");
957 
958     VerifyOrExit(GetTxtData() != nullptr, error = kErrorParse);
959 
960     aEntry.mKey = keyBuffer;
961 
962     while ((cur = GetTxtData() + GetTxtDataPosition()) < GetTxtDataEnd())
963     {
964         length = static_cast<uint8_t>(*cur);
965 
966         cur++;
967         VerifyOrExit(cur + length <= GetTxtDataEnd(), error = kErrorParse);
968         IncreaseTxtDataPosition(sizeof(uint8_t) + length);
969 
970         // Silently skip over an empty string or if the string starts with
971         // a `=` character (i.e., missing key) - RFC 6763 - section 6.4.
972 
973         if ((length == 0) || (cur[0] == kKeyValueSeparator))
974         {
975             continue;
976         }
977 
978         for (index = 0; index < length; index++)
979         {
980             if (cur[index] == kKeyValueSeparator)
981             {
982                 keyBuffer[index++]  = kNullChar; // Increment index to skip over `=`.
983                 aEntry.mValue       = reinterpret_cast<const uint8_t *>(&cur[index]);
984                 aEntry.mValueLength = length - index;
985                 ExitNow();
986             }
987 
988             if (index >= kMaxKeyLength)
989             {
990                 // The key is larger than recommended max key length.
991                 // In this case, we return the full encoded string in
992                 // `mValue` and `mValueLength` and set `mKey` to
993                 // `nullptr`.
994 
995                 aEntry.mKey         = nullptr;
996                 aEntry.mValue       = reinterpret_cast<const uint8_t *>(cur);
997                 aEntry.mValueLength = length;
998                 ExitNow();
999             }
1000 
1001             keyBuffer[index] = cur[index];
1002         }
1003 
1004         // If we reach the end of the string without finding `=` then
1005         // it is a boolean key attribute (encoded as "key").
1006 
1007         keyBuffer[index]    = kNullChar;
1008         aEntry.mValue       = nullptr;
1009         aEntry.mValueLength = 0;
1010         ExitNow();
1011     }
1012 
1013     error = kErrorNotFound;
1014 
1015 exit:
1016     return error;
1017 }
1018 
AppendTo(Message & aMessage) const1019 Error TxtEntry::AppendTo(Message &aMessage) const
1020 {
1021     Appender appender(aMessage);
1022 
1023     return AppendTo(appender);
1024 }
1025 
AppendTo(Appender & aAppender) const1026 Error TxtEntry::AppendTo(Appender &aAppender) const
1027 {
1028     Error    error = kErrorNone;
1029     uint16_t keyLength;
1030     char     separator = kKeyValueSeparator;
1031 
1032     if (mKey == nullptr)
1033     {
1034         VerifyOrExit((mValue != nullptr) && (mValueLength != 0));
1035         error = aAppender.AppendBytes(mValue, mValueLength);
1036         ExitNow();
1037     }
1038 
1039     keyLength = StringLength(mKey, static_cast<uint16_t>(kMaxKeyValueEncodedSize) + 1);
1040 
1041     VerifyOrExit(kMinKeyLength <= keyLength, error = kErrorInvalidArgs);
1042 
1043     if (mValue == nullptr)
1044     {
1045         // Treat as a boolean attribute and encoded as "key" (with no `=`).
1046         SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength)));
1047         error = aAppender.AppendBytes(mKey, keyLength);
1048         ExitNow();
1049     }
1050 
1051     // Treat as key/value and encode as "key=value", value may be empty.
1052 
1053     VerifyOrExit(mValueLength + keyLength + sizeof(char) <= kMaxKeyValueEncodedSize, error = kErrorInvalidArgs);
1054 
1055     SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength + mValueLength + sizeof(char))));
1056     SuccessOrExit(error = aAppender.AppendBytes(mKey, keyLength));
1057     SuccessOrExit(error = aAppender.Append(separator));
1058     error = aAppender.AppendBytes(mValue, mValueLength);
1059 
1060 exit:
1061     return error;
1062 }
1063 
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,Message & aMessage)1064 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, Message &aMessage)
1065 {
1066     Appender appender(aMessage);
1067 
1068     return AppendEntries(aEntries, aNumEntries, appender);
1069 }
1070 
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,MutableData<kWithUint16Length> & aData)1071 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, MutableData<kWithUint16Length> &aData)
1072 {
1073     Error    error;
1074     Appender appender(aData.GetBytes(), aData.GetLength());
1075 
1076     SuccessOrExit(error = AppendEntries(aEntries, aNumEntries, appender));
1077     appender.GetAsData(aData);
1078 
1079 exit:
1080     return error;
1081 }
1082 
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,Appender & aAppender)1083 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, Appender &aAppender)
1084 {
1085     Error error = kErrorNone;
1086 
1087     for (uint16_t index = 0; index < aNumEntries; index++)
1088     {
1089         SuccessOrExit(error = aEntries[index].AppendTo(aAppender));
1090     }
1091 
1092     if (aAppender.GetAppendedLength() == 0)
1093     {
1094         error = aAppender.Append<uint8_t>(0);
1095     }
1096 
1097 exit:
1098     return error;
1099 }
1100 
IsValid(void) const1101 bool AaaaRecord::IsValid(void) const
1102 {
1103     return GetType() == Dns::ResourceRecord::kTypeAaaa && GetSize() == sizeof(*this);
1104 }
1105 
IsValid(void) const1106 bool KeyRecord::IsValid(void) const { return GetType() == Dns::ResourceRecord::kTypeKey; }
1107 
1108 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
void Ecdsa256KeyRecord::Init(void)
{
    // Initializes the base KEY record fields, then sets the algorithm to
    // ECDSA P-256/SHA-256. The algorithm is set after the base `Init()` so
    // it is not affected by whatever the base initialization writes.
    KeyRecord::Init();
    SetAlgorithm(kAlgorithmEcdsaP256Sha256);
}
1114 
IsValid(void) const1115 bool Ecdsa256KeyRecord::IsValid(void) const
1116 {
1117     return KeyRecord::IsValid() && GetLength() == sizeof(*this) - sizeof(ResourceRecord) &&
1118            GetAlgorithm() == kAlgorithmEcdsaP256Sha256;
1119 }
1120 #endif
1121 
IsValid(void) const1122 bool SigRecord::IsValid(void) const
1123 {
1124     return GetType() == Dns::ResourceRecord::kTypeSig && GetLength() >= sizeof(*this) - sizeof(ResourceRecord);
1125 }
1126 
InitAsShortVariant(uint32_t aLeaseInterval)1127 void LeaseOption::InitAsShortVariant(uint32_t aLeaseInterval)
1128 {
1129     SetOptionCode(kUpdateLease);
1130     SetOptionLength(kShortLength);
1131     SetLeaseInterval(aLeaseInterval);
1132 }
1133 
InitAsLongVariant(uint32_t aLeaseInterval,uint32_t aKeyLeaseInterval)1134 void LeaseOption::InitAsLongVariant(uint32_t aLeaseInterval, uint32_t aKeyLeaseInterval)
1135 {
1136     SetOptionCode(kUpdateLease);
1137     SetOptionLength(kLongLength);
1138     SetLeaseInterval(aLeaseInterval);
1139     SetKeyLeaseInterval(aKeyLeaseInterval);
1140 }
1141 
IsValid(void) const1142 bool LeaseOption::IsValid(void) const
1143 {
1144     bool isValid = false;
1145 
1146     VerifyOrExit((GetOptionLength() == kShortLength) || (GetOptionLength() >= kLongLength));
1147     isValid = (GetLeaseInterval() <= GetKeyLeaseInterval());
1148 
1149 exit:
1150     return isValid;
1151 }
1152 
ReadFrom(const Message & aMessage,uint16_t aOffset,uint16_t aLength)1153 Error LeaseOption::ReadFrom(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
1154 {
1155     Error    error = kErrorNone;
1156     uint16_t endOffset;
1157 
1158     VerifyOrExit(static_cast<uint32_t>(aOffset) + aLength <= aMessage.GetLength(), error = kErrorParse);
1159 
1160     endOffset = aOffset + aLength;
1161 
1162     while (aOffset < endOffset)
1163     {
1164         uint16_t size;
1165 
1166         SuccessOrExit(error = aMessage.Read(aOffset, this, sizeof(Option)));
1167 
1168         VerifyOrExit(aOffset + GetSize() <= endOffset, error = kErrorParse);
1169 
1170         size = static_cast<uint16_t>(GetSize());
1171 
1172         if (GetOptionCode() == kUpdateLease)
1173         {
1174             VerifyOrExit(GetOptionLength() >= kShortLength, error = kErrorParse);
1175 
1176             IgnoreError(aMessage.Read(aOffset, this, Min(size, static_cast<uint16_t>(sizeof(LeaseOption)))));
1177             VerifyOrExit(IsValid(), error = kErrorParse);
1178 
1179             ExitNow();
1180         }
1181 
1182         aOffset += size;
1183     }
1184 
1185     error = kErrorNotFound;
1186 
1187 exit:
1188     return error;
1189 }
1190 
ReadPtrName(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const1191 Error PtrRecord::ReadPtrName(const Message &aMessage,
1192                              uint16_t      &aOffset,
1193                              char          *aLabelBuffer,
1194                              uint8_t        aLabelBufferSize,
1195                              char          *aNameBuffer,
1196                              uint16_t       aNameBufferSize) const
1197 {
1198     Error    error       = kErrorNone;
1199     uint16_t startOffset = aOffset - sizeof(PtrRecord); // start of `PtrRecord`.
1200 
1201     // Verify that the name is within the record data length.
1202     SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
1203     VerifyOrExit(aOffset <= startOffset + GetSize(), error = kErrorParse);
1204 
1205     aOffset = startOffset + sizeof(PtrRecord);
1206     SuccessOrExit(error = Name::ReadLabel(aMessage, aOffset, aLabelBuffer, aLabelBufferSize));
1207 
1208     if (aNameBuffer != nullptr)
1209     {
1210         SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
1211     }
1212 
1213     aOffset = startOffset;
1214     error   = SkipRecord(aMessage, aOffset);
1215 
1216 exit:
1217     return error;
1218 }
1219 
ReadTxtData(const Message & aMessage,uint16_t & aOffset,uint8_t * aTxtBuffer,uint16_t & aTxtBufferSize) const1220 Error TxtRecord::ReadTxtData(const Message &aMessage,
1221                              uint16_t      &aOffset,
1222                              uint8_t       *aTxtBuffer,
1223                              uint16_t      &aTxtBufferSize) const
1224 {
1225     Error error = kErrorNone;
1226 
1227     SuccessOrExit(error = aMessage.Read(aOffset, aTxtBuffer, Min(GetLength(), aTxtBufferSize)));
1228     aOffset += GetLength();
1229 
1230     VerifyOrExit(GetLength() <= aTxtBufferSize, error = kErrorNoBufs);
1231     aTxtBufferSize = GetLength();
1232     VerifyOrExit(VerifyTxtData(aTxtBuffer, aTxtBufferSize, /* aAllowEmpty */ true), error = kErrorParse);
1233 
1234 exit:
1235     return error;
1236 }
1237 
VerifyTxtData(const uint8_t * aTxtData,uint16_t aTxtLength,bool aAllowEmpty)1238 bool TxtRecord::VerifyTxtData(const uint8_t *aTxtData, uint16_t aTxtLength, bool aAllowEmpty)
1239 {
1240     bool    valid          = false;
1241     uint8_t curEntryLength = 0;
1242 
1243     // Per RFC 1035, TXT-DATA MUST have one or more <character-string>s.
1244     VerifyOrExit(aAllowEmpty || aTxtLength > 0);
1245 
1246     for (uint16_t i = 0; i < aTxtLength; ++i)
1247     {
1248         if (curEntryLength == 0)
1249         {
1250             curEntryLength = aTxtData[i];
1251         }
1252         else
1253         {
1254             --curEntryLength;
1255         }
1256     }
1257 
1258     valid = (curEntryLength == 0);
1259 
1260 exit:
1261     return valid;
1262 }
1263 
1264 } // namespace Dns
1265 } // namespace ot
1266