1 /*
2  *  Copyright (c) 2020, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements generating and processing of DNS headers and helper functions/methods.
32  */
33 
34 #include "dns_types.hpp"
35 
36 #include "common/code_utils.hpp"
37 #include "common/debug.hpp"
38 #include "common/instance.hpp"
39 #include "common/logging.hpp"
40 #include "common/random.hpp"
41 #include "common/string.hpp"
42 
43 namespace ot {
44 namespace Dns {
45 
46 using ot::Encoding::BigEndian::HostSwap16;
47 
SetRandomMessageId(void)48 Error Header::SetRandomMessageId(void)
49 {
50     return Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&mMessageId), sizeof(mMessageId));
51 }
52 
ResponseCodeToError(Response aResponse)53 Error Header::ResponseCodeToError(Response aResponse)
54 {
55     Error error = kErrorFailed;
56 
57     switch (aResponse)
58     {
59     case kResponseSuccess:
60         error = kErrorNone;
61         break;
62 
63     case kResponseFormatError:   // Server unable to interpret request due to format error.
64     case kResponseBadName:       // Bad name.
65     case kResponseBadTruncation: // Bad truncation.
66     case kResponseNotZone:       // A name is not in the zone.
67         error = kErrorParse;
68         break;
69 
70     case kResponseServerFailure: // Server encountered an internal failure.
71         error = kErrorFailed;
72         break;
73 
74     case kResponseNameError:       // Name that ought to exist, does not exists.
75     case kResponseRecordNotExists: // Some RRset that ought to exist, does not exist.
76         error = kErrorNotFound;
77         break;
78 
79     case kResponseNotImplemented: // Server does not support the query type (OpCode).
80         error = kErrorNotImplemented;
81         break;
82 
83     case kResponseBadAlg: // Bad algorithm.
84         error = kErrorNotCapable;
85         break;
86 
87     case kResponseNameExists:   // Some name that ought not to exist, does exist.
88     case kResponseRecordExists: // Some RRset that ought not to exist, does exist.
89         error = kErrorDuplicated;
90         break;
91 
92     case kResponseRefused: // Server refused to perform operation for policy or security reasons.
93     case kResponseNotAuth: // Service is not authoritative for zone.
94         error = kErrorSecurity;
95         break;
96 
97     default:
98         break;
99     }
100 
101     return error;
102 }
103 
AppendTo(Message & aMessage) const104 Error Name::AppendTo(Message &aMessage) const
105 {
106     Error error;
107 
108     if (IsEmpty())
109     {
110         error = AppendTerminator(aMessage);
111     }
112     else if (IsFromCString())
113     {
114         error = AppendName(GetAsCString(), aMessage);
115     }
116     else
117     {
118         // Name is from a message. Read labels one by one from
119         // `mMessage` and and append each to the `aMessage`.
120 
121         LabelIterator iterator(*mMessage, mOffset);
122 
123         while (true)
124         {
125             error = iterator.GetNextLabel();
126 
127             switch (error)
128             {
129             case kErrorNone:
130                 SuccessOrExit(error = iterator.AppendLabel(aMessage));
131                 break;
132 
133             case kErrorNotFound:
134                 // We reached the end of name successfully.
135                 error = AppendTerminator(aMessage);
136 
137                 OT_FALL_THROUGH;
138 
139             default:
140                 ExitNow();
141             }
142         }
143     }
144 
145 exit:
146     return error;
147 }
148 
AppendLabel(const char * aLabel,Message & aMessage)149 Error Name::AppendLabel(const char *aLabel, Message &aMessage)
150 {
151     return AppendLabel(aLabel, static_cast<uint8_t>(StringLength(aLabel, kMaxLabelSize)), aMessage);
152 }
153 
AppendLabel(const char * aLabel,uint8_t aLength,Message & aMessage)154 Error Name::AppendLabel(const char *aLabel, uint8_t aLength, Message &aMessage)
155 {
156     Error error = kErrorNone;
157 
158     VerifyOrExit((0 < aLength) && (aLength <= kMaxLabelLength), error = kErrorInvalidArgs);
159 
160     SuccessOrExit(error = aMessage.Append(aLength));
161     error = aMessage.AppendBytes(aLabel, aLength);
162 
163 exit:
164     return error;
165 }
166 
AppendMultipleLabels(const char * aLabels,Message & aMessage)167 Error Name::AppendMultipleLabels(const char *aLabels, Message &aMessage)
168 {
169     return AppendMultipleLabels(aLabels, kMaxNameLength, aMessage);
170 }
171 
AppendMultipleLabels(const char * aLabels,uint8_t aLength,Message & aMessage)172 Error Name::AppendMultipleLabels(const char *aLabels, uint8_t aLength, Message &aMessage)
173 {
174     Error    error           = kErrorNone;
175     uint16_t index           = 0;
176     uint16_t labelStartIndex = 0;
177     char     ch;
178 
179     VerifyOrExit(aLabels != nullptr);
180 
181     do
182     {
183         ch = index < aLength ? aLabels[index] : static_cast<char>(kNullChar);
184 
185         if ((ch == kNullChar) || (ch == kLabelSeperatorChar))
186         {
187             uint8_t labelLength = static_cast<uint8_t>(index - labelStartIndex);
188 
189             if (labelLength == 0)
190             {
191                 // Empty label (e.g., consecutive dots) is invalid, but we
192                 // allow for two cases: (1) where `aLabels` ends with a dot
193                 // (`labelLength` is zero but we are at end of `aLabels` string
194                 // and `ch` is null char. (2) if `aLabels` is just "." (we
195                 // see a dot at index 0, and index 1 is null char).
196 
197                 error =
198                     ((ch == kNullChar) || ((index == 0) && (aLabels[1] == kNullChar))) ? kErrorNone : kErrorInvalidArgs;
199                 ExitNow();
200             }
201 
202             VerifyOrExit(index + 1 < kMaxEncodedLength, error = kErrorInvalidArgs);
203             SuccessOrExit(error = AppendLabel(&aLabels[labelStartIndex], labelLength, aMessage));
204 
205             labelStartIndex = index + 1;
206         }
207 
208         index++;
209 
210     } while (ch != kNullChar);
211 
212 exit:
213     return error;
214 }
215 
AppendTerminator(Message & aMessage)216 Error Name::AppendTerminator(Message &aMessage)
217 {
218     uint8_t terminator = 0;
219 
220     return aMessage.Append(terminator);
221 }
222 
AppendPointerLabel(uint16_t aOffset,Message & aMessage)223 Error Name::AppendPointerLabel(uint16_t aOffset, Message &aMessage)
224 {
225     Error    error;
226     uint16_t value;
227 
228 #if OPENTHREAD_CONFIG_REFERENCE_DEVICE_ENABLE
229     if (!Instance::IsDnsNameCompressionEnabled())
230     {
231         // If "DNS name compression" mode is disabled, instead of
232         // appending the pointer label, read the name from the message
233         // and append it uncompressed. Note that the `aOffset` parameter
234         // in this method is given relative to the start of DNS header
235         // in `aMessage` (which `aMessage.GetOffset()` specifies).
236 
237         error = Name(aMessage, aOffset + aMessage.GetOffset()).AppendTo(aMessage);
238         ExitNow();
239     }
240 #endif
241 
242     // A pointer label takes the form of a two byte sequence as a
243     // `uint16_t` value. The first two bits are ones. This allows a
244     // pointer to be distinguished from a text label, since the text
245     // label must begin with two zero bits (note that labels are
246     // restricted to 63 octets or less). The next 14-bits specify
247     // an offset value relative to start of DNS header.
248 
249     OT_ASSERT(aOffset < kPointerLabelTypeUint16);
250 
251     value = HostSwap16(aOffset | kPointerLabelTypeUint16);
252 
253     ExitNow(error = aMessage.Append(value));
254 
255 exit:
256     return error;
257 }
258 
AppendName(const char * aName,Message & aMessage)259 Error Name::AppendName(const char *aName, Message &aMessage)
260 {
261     Error error;
262 
263     SuccessOrExit(error = AppendMultipleLabels(aName, aMessage));
264     error = AppendTerminator(aMessage);
265 
266 exit:
267     return error;
268 }
269 
ParseName(const Message & aMessage,uint16_t & aOffset)270 Error Name::ParseName(const Message &aMessage, uint16_t &aOffset)
271 {
272     Error         error;
273     LabelIterator iterator(aMessage, aOffset);
274 
275     while (true)
276     {
277         error = iterator.GetNextLabel();
278 
279         switch (error)
280         {
281         case kErrorNone:
282             break;
283 
284         case kErrorNotFound:
285             // We reached the end of name successfully.
286             aOffset = iterator.mNameEndOffset;
287             error   = kErrorNone;
288 
289             OT_FALL_THROUGH;
290 
291         default:
292             ExitNow();
293         }
294     }
295 
296 exit:
297     return error;
298 }
299 
ReadLabel(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t & aLabelLength)300 Error Name::ReadLabel(const Message &aMessage, uint16_t &aOffset, char *aLabelBuffer, uint8_t &aLabelLength)
301 {
302     Error         error;
303     LabelIterator iterator(aMessage, aOffset);
304 
305     SuccessOrExit(error = iterator.GetNextLabel());
306     SuccessOrExit(error = iterator.ReadLabel(aLabelBuffer, aLabelLength, /* aAllowDotCharInLabel */ true));
307     aOffset = iterator.mNextLabelOffset;
308 
309 exit:
310     return error;
311 }
312 
ReadName(const Message & aMessage,uint16_t & aOffset,char * aNameBuffer,uint16_t aNameBufferSize)313 Error Name::ReadName(const Message &aMessage, uint16_t &aOffset, char *aNameBuffer, uint16_t aNameBufferSize)
314 {
315     Error         error;
316     LabelIterator iterator(aMessage, aOffset);
317     bool          firstLabel = true;
318     uint8_t       labelLength;
319 
320     while (true)
321     {
322         error = iterator.GetNextLabel();
323 
324         switch (error)
325         {
326         case kErrorNone:
327 
328             if (!firstLabel)
329             {
330                 *aNameBuffer++ = kLabelSeperatorChar;
331                 aNameBufferSize--;
332 
333                 // No need to check if we have reached end of the name buffer
334                 // here since `iterator.ReadLabel()` would verify it.
335             }
336 
337             labelLength = static_cast<uint8_t>(OT_MIN(static_cast<uint8_t>(kMaxLabelSize), aNameBufferSize));
338             SuccessOrExit(error = iterator.ReadLabel(aNameBuffer, labelLength, /* aAllowDotCharInLabel */ false));
339             aNameBuffer += labelLength;
340             aNameBufferSize -= labelLength;
341             firstLabel = false;
342             break;
343 
344         case kErrorNotFound:
345             // We reach the end of name successfully. Always add a terminating dot
346             // at the end.
347             *aNameBuffer++ = kLabelSeperatorChar;
348             aNameBufferSize--;
349             VerifyOrExit(aNameBufferSize >= sizeof(uint8_t), error = kErrorNoBufs);
350             *aNameBuffer = kNullChar;
351             aOffset      = iterator.mNameEndOffset;
352             error        = kErrorNone;
353 
354             OT_FALL_THROUGH;
355 
356         default:
357             ExitNow();
358         }
359     }
360 
361 exit:
362     return error;
363 }
364 
CompareLabel(const Message & aMessage,uint16_t & aOffset,const char * aLabel)365 Error Name::CompareLabel(const Message &aMessage, uint16_t &aOffset, const char *aLabel)
366 {
367     Error         error;
368     LabelIterator iterator(aMessage, aOffset);
369 
370     SuccessOrExit(error = iterator.GetNextLabel());
371     VerifyOrExit(iterator.CompareLabel(aLabel, /* aIsSingleLabel */ true), error = kErrorNotFound);
372     aOffset = iterator.mNextLabelOffset;
373 
374 exit:
375     return error;
376 }
377 
CompareName(const Message & aMessage,uint16_t & aOffset,const char * aName)378 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const char *aName)
379 {
380     Error         error;
381     LabelIterator iterator(aMessage, aOffset);
382     bool          matches = true;
383 
384     if (*aName == kLabelSeperatorChar)
385     {
386         aName++;
387         VerifyOrExit(*aName == kNullChar, error = kErrorInvalidArgs);
388     }
389 
390     while (true)
391     {
392         error = iterator.GetNextLabel();
393 
394         switch (error)
395         {
396         case kErrorNone:
397             if (matches && !iterator.CompareLabel(aName, /* aIsSingleLabel */ false))
398             {
399                 matches = false;
400             }
401 
402             break;
403 
404         case kErrorNotFound:
405             // We reached the end of the name in `aMessage`. We check if
406             // all the previous labels matched so far, and we are also
407             // at the end of `aName` string (see null char), then we
408             // return `kErrorNone` indicating a successful comparison
409             // (full match). Otherwise we return `kErrorNotFound` to
410             // indicate failed comparison.
411 
412             if (matches && (*aName == kNullChar))
413             {
414                 error = kErrorNone;
415             }
416 
417             aOffset = iterator.mNameEndOffset;
418 
419             OT_FALL_THROUGH;
420 
421         default:
422             ExitNow();
423         }
424     }
425 
426 exit:
427     return error;
428 }
429 
CompareName(const Message & aMessage,uint16_t & aOffset,const Message & aMessage2,uint16_t aOffset2)430 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Message &aMessage2, uint16_t aOffset2)
431 {
432     Error         error;
433     LabelIterator iterator(aMessage, aOffset);
434     LabelIterator iterator2(aMessage2, aOffset2);
435     bool          matches = true;
436 
437     while (true)
438     {
439         error = iterator.GetNextLabel();
440 
441         switch (error)
442         {
443         case kErrorNone:
444             // If all the previous labels matched so far, then verify
445             // that we can get the next label on `iterator2` and that it
446             // matches the label from `iterator`.
447             if (matches && (iterator2.GetNextLabel() != kErrorNone || !iterator.CompareLabel(iterator2)))
448             {
449                 matches = false;
450             }
451 
452             break;
453 
454         case kErrorNotFound:
455             // We reached the end of the name in `aMessage`. We check
456             // that `iterator2` is also at its end, and if all previous
457             // labels matched we return `kErrorNone`.
458 
459             if (matches && (iterator2.GetNextLabel() == kErrorNotFound))
460             {
461                 error = kErrorNone;
462             }
463 
464             aOffset = iterator.mNameEndOffset;
465 
466             OT_FALL_THROUGH;
467 
468         default:
469             ExitNow();
470         }
471     }
472 
473 exit:
474     return error;
475 }
476 
CompareName(const Message & aMessage,uint16_t & aOffset,const Name & aName)477 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Name &aName)
478 {
479     return aName.IsFromCString()
480                ? CompareName(aMessage, aOffset, aName.mString)
481                : (aName.IsFromMessage() ? CompareName(aMessage, aOffset, *aName.mMessage, aName.mOffset)
482                                         : ParseName(aMessage, aOffset));
483 }
484 
GetNextLabel(void)485 Error Name::LabelIterator::GetNextLabel(void)
486 {
487     Error error;
488 
489     while (true)
490     {
491         uint8_t labelLength;
492         uint8_t labelType;
493 
494         SuccessOrExit(error = mMessage.Read(mNextLabelOffset, labelLength));
495 
496         labelType = labelLength & kLabelTypeMask;
497 
498         if (labelType == kTextLabelType)
499         {
500             if (labelLength == 0)
501             {
502                 // Zero label length indicates end of a name.
503 
504                 if (!IsEndOffsetSet())
505                 {
506                     mNameEndOffset = mNextLabelOffset + sizeof(uint8_t);
507                 }
508 
509                 ExitNow(error = kErrorNotFound);
510             }
511 
512             mLabelStartOffset = mNextLabelOffset + sizeof(uint8_t);
513             mLabelLength      = labelLength;
514             mNextLabelOffset  = mLabelStartOffset + labelLength;
515             ExitNow();
516         }
517         else if (labelType == kPointerLabelType)
518         {
519             // A pointer label takes the form of a two byte sequence as a
520             // `uint16_t` value. The first two bits are ones. The next 14 bits
521             // specify an offset value from the start of the DNS header.
522 
523             uint16_t pointerValue;
524 
525             SuccessOrExit(error = mMessage.Read(mNextLabelOffset, pointerValue));
526 
527             if (!IsEndOffsetSet())
528             {
529                 mNameEndOffset = mNextLabelOffset + sizeof(uint16_t);
530             }
531 
532             // `mMessage.GetOffset()` must point to the start of the
533             // DNS header.
534             mNextLabelOffset = mMessage.GetOffset() + (HostSwap16(pointerValue) & kPointerLabelOffsetMask);
535 
536             // Go back through the `while(true)` loop to get the next label.
537         }
538         else
539         {
540             ExitNow(error = kErrorParse);
541         }
542     }
543 
544 exit:
545     return error;
546 }
547 
ReadLabel(char * aLabelBuffer,uint8_t & aLabelLength,bool aAllowDotCharInLabel) const548 Error Name::LabelIterator::ReadLabel(char *aLabelBuffer, uint8_t &aLabelLength, bool aAllowDotCharInLabel) const
549 {
550     Error error;
551 
552     VerifyOrExit(mLabelLength < aLabelLength, error = kErrorNoBufs);
553 
554     SuccessOrExit(error = mMessage.Read(mLabelStartOffset, aLabelBuffer, mLabelLength));
555     aLabelBuffer[mLabelLength] = kNullChar;
556     aLabelLength               = mLabelLength;
557 
558     if (!aAllowDotCharInLabel)
559     {
560         VerifyOrExit(StringFind(aLabelBuffer, kLabelSeperatorChar) == nullptr, error = kErrorParse);
561     }
562 
563 exit:
564     return error;
565 }
566 
CompareLabel(const char * & aName,bool aIsSingleLabel) const567 bool Name::LabelIterator::CompareLabel(const char *&aName, bool aIsSingleLabel) const
568 {
569     // This method compares the current label in the iterator with the
570     // `aName` string. `aIsSingleLabel` indicates whether `aName` is a
571     // single label, or a sequence of labels separated by dot '.' char.
572     // If the label matches `aName`, then `aName` pointer is moved
573     // forward to the start of the next label (skipping over the `.`
574     // char). This method returns `true` when the labels match, `false`
575     // otherwise.
576 
577     bool matches = false;
578 
579     VerifyOrExit(StringLength(aName, mLabelLength) == mLabelLength);
580     matches = mMessage.CompareBytes(mLabelStartOffset, aName, mLabelLength);
581 
582     VerifyOrExit(matches);
583 
584     aName += mLabelLength;
585 
586     // If `aName` is a single label, we should be also at the end of the
587     // `aName` string. Otherwise, we should see either null or dot '.'
588     // character (in case `aName` contains multiple labels).
589 
590     matches = (*aName == kNullChar);
591 
592     if (!aIsSingleLabel && (*aName == kLabelSeperatorChar))
593     {
594         matches = true;
595         aName++;
596     }
597 
598 exit:
599     return matches;
600 }
601 
CompareLabel(const LabelIterator & aOtherIterator) const602 bool Name::LabelIterator::CompareLabel(const LabelIterator &aOtherIterator) const
603 {
604     // This method compares the current label in the iterator with the
605     // label from another iterator.
606 
607     return (mLabelLength == aOtherIterator.mLabelLength) &&
608            mMessage.CompareBytes(mLabelStartOffset, aOtherIterator.mMessage, aOtherIterator.mLabelStartOffset,
609                                  mLabelLength);
610 }
611 
AppendLabel(Message & aMessage) const612 Error Name::LabelIterator::AppendLabel(Message &aMessage) const
613 {
614     // This method reads and appends the current label in the iterator
615     // to `aMessage`.
616 
617     Error error;
618 
619     VerifyOrExit((0 < mLabelLength) && (mLabelLength <= kMaxLabelLength), error = kErrorInvalidArgs);
620     SuccessOrExit(error = aMessage.Append(mLabelLength));
621     error = aMessage.AppendBytesFromMessage(mMessage, mLabelStartOffset, mLabelLength);
622 
623 exit:
624     return error;
625 }
626 
IsSubDomainOf(const char * aName,const char * aDomain)627 bool Name::IsSubDomainOf(const char *aName, const char *aDomain)
628 {
629     bool     match        = false;
630     uint16_t nameLength   = StringLength(aName, kMaxNameLength);
631     uint16_t domainLength = StringLength(aDomain, kMaxNameLength);
632 
633     if (nameLength > 0 && aName[nameLength - 1] == kLabelSeperatorChar)
634     {
635         --nameLength;
636     }
637 
638     if (domainLength > 0 && aDomain[domainLength - 1] == kLabelSeperatorChar)
639     {
640         --domainLength;
641     }
642 
643     VerifyOrExit(nameLength >= domainLength);
644     aName += nameLength - domainLength;
645 
646     if (nameLength > domainLength)
647     {
648         VerifyOrExit(aName[-1] == kLabelSeperatorChar);
649     }
650     VerifyOrExit(memcmp(aName, aDomain, domainLength) == 0);
651 
652     match = true;
653 
654 exit:
655     return match;
656 }
657 
ParseRecords(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords)658 Error ResourceRecord::ParseRecords(const Message &aMessage, uint16_t &aOffset, uint16_t aNumRecords)
659 {
660     Error error = kErrorNone;
661 
662     while (aNumRecords > 0)
663     {
664         ResourceRecord record;
665 
666         SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
667         SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
668         aOffset += static_cast<uint16_t>(record.GetSize());
669         aNumRecords--;
670     }
671 
672 exit:
673     return error;
674 }
675 
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t & aNumRecords,const Name & aName)676 Error ResourceRecord::FindRecord(const Message &aMessage, uint16_t &aOffset, uint16_t &aNumRecords, const Name &aName)
677 {
678     Error error;
679 
680     while (aNumRecords > 0)
681     {
682         bool           matches = true;
683         ResourceRecord record;
684 
685         error = Name::CompareName(aMessage, aOffset, aName);
686 
687         switch (error)
688         {
689         case kErrorNone:
690             break;
691         case kErrorNotFound:
692             matches = false;
693             break;
694         default:
695             ExitNow();
696         }
697 
698         SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
699         aNumRecords--;
700         VerifyOrExit(!matches);
701         aOffset += static_cast<uint16_t>(record.GetSize());
702     }
703 
704     error = kErrorNotFound;
705 
706 exit:
707     return error;
708 }
709 
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords,uint16_t aIndex,const Name & aName,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)710 Error ResourceRecord::FindRecord(const Message & aMessage,
711                                  uint16_t &      aOffset,
712                                  uint16_t        aNumRecords,
713                                  uint16_t        aIndex,
714                                  const Name &    aName,
715                                  uint16_t        aType,
716                                  ResourceRecord &aRecord,
717                                  uint16_t        aMinRecordSize)
718 {
719     // This static method searches in `aMessage` starting from `aOffset`
720     // up to maximum of `aNumRecords`, for the `(aIndex+1)`th
721     // occurrence of a resource record of type `aType` with record name
722     // matching `aName`. It also verifies that the record size is larger
723     // than `aMinRecordSize`. If found, `aMinRecordSize` bytes from the
724     // record are read and copied into `aRecord`. In this case `aOffset`
725     // is updated to point to the last record byte read from the message
726     // (so that the caller can read any remaining fields in the record
727     // data).
728 
729     Error    error;
730     uint16_t offset = aOffset;
731     uint16_t recordOffset;
732 
733     while (aNumRecords > 0)
734     {
735         SuccessOrExit(error = FindRecord(aMessage, offset, aNumRecords, aName));
736 
737         // Save the offset to start of `ResourceRecord` fields.
738         recordOffset = offset;
739 
740         error = ReadRecord(aMessage, offset, aType, aRecord, aMinRecordSize);
741 
742         if (error == kErrorNotFound)
743         {
744             // `ReadRecord()` already updates the `offset` to skip
745             // over a non-matching record.
746             continue;
747         }
748 
749         SuccessOrExit(error);
750 
751         if (aIndex == 0)
752         {
753             aOffset = offset;
754             ExitNow();
755         }
756 
757         aIndex--;
758 
759         // Skip over the record.
760         offset = static_cast<uint16_t>(recordOffset + aRecord.GetSize());
761     }
762 
763     error = kErrorNotFound;
764 
765 exit:
766     return error;
767 }
768 
ReadRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)769 Error ResourceRecord::ReadRecord(const Message & aMessage,
770                                  uint16_t &      aOffset,
771                                  uint16_t        aType,
772                                  ResourceRecord &aRecord,
773                                  uint16_t        aMinRecordSize)
774 {
775     // This static method tries to read a matching resource record of a
776     // given type and a minimum record size from a message. The `aType`
777     // value of `kTypeAny` matches any type.  If the record in the
778     // message does not match, it skips over the record. Please see
779     // `ReadRecord<RecordType>()` for more details.
780 
781     Error          error;
782     ResourceRecord record;
783 
784     SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
785 
786     if (((aType == kTypeAny) || (record.GetType() == aType)) && (record.GetSize() >= aMinRecordSize))
787     {
788         IgnoreError(aMessage.Read(aOffset, &aRecord, aMinRecordSize));
789         aOffset += aMinRecordSize;
790     }
791     else
792     {
793         // Skip over the entire record.
794         aOffset += static_cast<uint16_t>(record.GetSize());
795         error = kErrorNotFound;
796     }
797 
798 exit:
799     return error;
800 }
801 
ReadName(const Message & aMessage,uint16_t & aOffset,uint16_t aStartOffset,char * aNameBuffer,uint16_t aNameBufferSize,bool aSkipRecord) const802 Error ResourceRecord::ReadName(const Message &aMessage,
803                                uint16_t &     aOffset,
804                                uint16_t       aStartOffset,
805                                char *         aNameBuffer,
806                                uint16_t       aNameBufferSize,
807                                bool           aSkipRecord) const
808 {
809     // This protected method parses and reads a name field in a record
810     // from a message. It is intended only for sub-classes of
811     // `ResourceRecord`.
812     //
813     // On input `aOffset` gives the offset in `aMessage` to the start of
814     // name field. `aStartOffset` gives the offset to the start of the
815     // `ResourceRecord`. `aSkipRecord` indicates whether to skip over
816     // the entire resource record or just the read name. On exit, when
817     // successfully read, `aOffset` is updated to either point after the
818     // end of record or after the the name field.
819     //
820     // When read successfully, this method returns `kErrorNone`. On a
821     // parse error (invalid format) returns `kErrorParse`. If the
822     // name does not fit in the given name buffer it returns
823     // `kErrorNoBufs`
824 
825     Error error = kErrorNone;
826 
827     SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
828     VerifyOrExit(aOffset <= aStartOffset + GetSize(), error = kErrorParse);
829 
830     VerifyOrExit(aSkipRecord);
831     aOffset = aStartOffset;
832     error   = SkipRecord(aMessage, aOffset);
833 
834 exit:
835     return error;
836 }
837 
SkipRecord(const Message & aMessage,uint16_t & aOffset) const838 Error ResourceRecord::SkipRecord(const Message &aMessage, uint16_t &aOffset) const
839 {
840     // This protected method parses and skips over a resource record
841     // in a message.
842     //
843     // On input `aOffset` gives the offset in `aMessage` to the start of
844     // the `ResourceRecord`. On exit, when successfully parsed, `aOffset`
845     // is updated to point to byte after the entire record.
846 
847     Error error;
848 
849     SuccessOrExit(error = CheckRecord(aMessage, aOffset));
850     aOffset += static_cast<uint16_t>(GetSize());
851 
852 exit:
853     return error;
854 }
855 
CheckRecord(const Message & aMessage,uint16_t aOffset) const856 Error ResourceRecord::CheckRecord(const Message &aMessage, uint16_t aOffset) const
857 {
858     // This method checks that the entire record (including record data)
859     // is present in `aMessage` at `aOffset` (pointing to the start of
860     // the `ResourceRecord` in `aMessage`).
861 
862     return (aOffset + GetSize() <= aMessage.GetLength()) ? kErrorNone : kErrorParse;
863 }
864 
ReadFrom(const Message & aMessage,uint16_t aOffset)865 Error ResourceRecord::ReadFrom(const Message &aMessage, uint16_t aOffset)
866 {
867     // This method reads the `ResourceRecord` from `aMessage` at
868     // `aOffset`. It verifies that the entire record (including record
869     // data) is present in the message.
870 
871     Error error;
872 
873     SuccessOrExit(error = aMessage.Read(aOffset, *this));
874     error = CheckRecord(aMessage, aOffset);
875 
876 exit:
877     return error;
878 }
879 
Init(const uint8_t * aTxtData,uint16_t aTxtDataLength)880 void TxtEntry::Iterator::Init(const uint8_t *aTxtData, uint16_t aTxtDataLength)
881 {
882     SetTxtData(aTxtData);
883     SetTxtDataLength(aTxtDataLength);
884     SetTxtDataPosition(0);
885 }
886 
GetNextEntry(TxtEntry & aEntry)887 Error TxtEntry::Iterator::GetNextEntry(TxtEntry &aEntry)
888 {
889     Error       error = kErrorNone;
890     uint8_t     length;
891     uint8_t     index;
892     const char *cur;
893     char *      keyBuffer = GetKeyBuffer();
894 
895     static_assert(sizeof(mChar) == TxtEntry::kMaxKeyLength + 1, "KeyBuffer cannot fit the max key length");
896 
897     VerifyOrExit(GetTxtData() != nullptr, error = kErrorParse);
898 
899     aEntry.mKey = keyBuffer;
900 
901     while ((cur = GetTxtData() + GetTxtDataPosition()) < GetTxtDataEnd())
902     {
903         length = static_cast<uint8_t>(*cur);
904 
905         cur++;
906         VerifyOrExit(cur + length <= GetTxtDataEnd(), error = kErrorParse);
907         IncreaseTxtDataPosition(sizeof(uint8_t) + length);
908 
909         // Silently skip over an empty string or if the string starts with
910         // a `=` character (i.e., missing key) - RFC 6763 - section 6.4.
911 
912         if ((length == 0) || (cur[0] == kKeyValueSeparator))
913         {
914             continue;
915         }
916 
917         for (index = 0; index < length; index++)
918         {
919             if (cur[index] == kKeyValueSeparator)
920             {
921                 keyBuffer[index++]  = kNullChar; // Increment index to skip over `=`.
922                 aEntry.mValue       = reinterpret_cast<const uint8_t *>(&cur[index]);
923                 aEntry.mValueLength = length - index;
924                 ExitNow();
925             }
926 
927             if (index >= kMaxKeyLength)
928             {
929                 // The key is larger than recommended max key length.
930                 // In this case, we return the full encoded string in
931                 // `mValue` and `mValueLength` and set `mKey` to
932                 // `nullptr`.
933 
934                 aEntry.mKey         = nullptr;
935                 aEntry.mValue       = reinterpret_cast<const uint8_t *>(cur);
936                 aEntry.mValueLength = length;
937                 ExitNow();
938             }
939 
940             keyBuffer[index] = cur[index];
941         }
942 
943         // If we reach the end of the string without finding `=` then
944         // it is a boolean key attribute (encoded as "key").
945 
946         keyBuffer[index]    = kNullChar;
947         aEntry.mValue       = nullptr;
948         aEntry.mValueLength = 0;
949         ExitNow();
950     }
951 
952     error = kErrorNotFound;
953 
954 exit:
955     return error;
956 }
957 
AppendTo(Message & aMessage) const958 Error TxtEntry::AppendTo(Message &aMessage) const
959 {
960     Error    error = kErrorNone;
961     uint16_t keyLength;
962     char     separator = kKeyValueSeparator;
963 
964     if (mKey == nullptr)
965     {
966         VerifyOrExit((mValue != nullptr) && (mValueLength != 0));
967         error = aMessage.AppendBytes(mValue, mValueLength);
968         ExitNow();
969     }
970 
971     keyLength = StringLength(mKey, static_cast<uint16_t>(kMaxKeyValueEncodedSize) + 1);
972 
973     VerifyOrExit(kMinKeyLength <= keyLength, error = kErrorInvalidArgs);
974 
975     if (mValue == nullptr)
976     {
977         // Treat as a boolean attribute and encoded as "key" (with no `=`).
978         SuccessOrExit(error = aMessage.Append<uint8_t>(static_cast<uint8_t>(keyLength)));
979         error = aMessage.AppendBytes(mKey, keyLength);
980         ExitNow();
981     }
982 
983     // Treat as key/value and encode as "key=value", value may be empty.
984 
985     VerifyOrExit(mValueLength + keyLength + sizeof(char) <= kMaxKeyValueEncodedSize, error = kErrorInvalidArgs);
986 
987     SuccessOrExit(error = aMessage.Append<uint8_t>(static_cast<uint8_t>(keyLength + mValueLength + sizeof(char))));
988     SuccessOrExit(error = aMessage.AppendBytes(mKey, keyLength));
989     SuccessOrExit(error = aMessage.Append(separator));
990     error = aMessage.AppendBytes(mValue, mValueLength);
991 
992 exit:
993     return error;
994 }
995 
AppendEntries(const TxtEntry * aEntries,uint8_t aNumEntries,Message & aMessage)996 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint8_t aNumEntries, Message &aMessage)
997 {
998     Error    error       = kErrorNone;
999     uint16_t startOffset = aMessage.GetLength();
1000 
1001     for (uint8_t index = 0; index < aNumEntries; index++)
1002     {
1003         SuccessOrExit(error = aEntries[index].AppendTo(aMessage));
1004     }
1005 
1006     if (aMessage.GetLength() == startOffset)
1007     {
1008         error = aMessage.Append<uint8_t>(0);
1009     }
1010 
1011 exit:
1012     return error;
1013 }
1014 
IsValid(void) const1015 bool AaaaRecord::IsValid(void) const
1016 {
1017     return GetType() == Dns::ResourceRecord::kTypeAaaa && GetSize() == sizeof(*this);
1018 }
1019 
IsValid(void) const1020 bool KeyRecord::IsValid(void) const
1021 {
1022     return GetType() == Dns::ResourceRecord::kTypeKey;
1023 }
1024 
1025 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
Init(void)1026 void Ecdsa256KeyRecord::Init(void)
1027 {
1028     KeyRecord::Init();
1029     SetAlgorithm(kAlgorithmEcdsaP256Sha256);
1030 }
1031 
IsValid(void) const1032 bool Ecdsa256KeyRecord::IsValid(void) const
1033 {
1034     return KeyRecord::IsValid() && GetLength() == sizeof(*this) - sizeof(ResourceRecord) &&
1035            GetAlgorithm() == kAlgorithmEcdsaP256Sha256;
1036 }
1037 #endif
1038 
IsValid(void) const1039 bool SigRecord::IsValid(void) const
1040 {
1041     return GetType() == Dns::ResourceRecord::kTypeSig && GetLength() >= sizeof(*this) - sizeof(ResourceRecord);
1042 }
1043 
IsValid(void) const1044 bool LeaseOption::IsValid(void) const
1045 {
1046     return GetLeaseInterval() <= GetKeyLeaseInterval();
1047 }
1048 
ReadPtrName(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const1049 Error PtrRecord::ReadPtrName(const Message &aMessage,
1050                              uint16_t &     aOffset,
1051                              char *         aLabelBuffer,
1052                              uint8_t        aLabelBufferSize,
1053                              char *         aNameBuffer,
1054                              uint16_t       aNameBufferSize) const
1055 {
1056     Error    error       = kErrorNone;
1057     uint16_t startOffset = aOffset - sizeof(PtrRecord); // start of `PtrRecord`.
1058 
1059     // Verify that the name is within the record data length.
1060     SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
1061     VerifyOrExit(aOffset <= startOffset + GetSize(), error = kErrorParse);
1062 
1063     aOffset = startOffset + sizeof(PtrRecord);
1064     SuccessOrExit(error = Name::ReadLabel(aMessage, aOffset, aLabelBuffer, aLabelBufferSize));
1065 
1066     if (aNameBuffer != nullptr)
1067     {
1068         SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
1069     }
1070 
1071     aOffset = startOffset;
1072     error   = SkipRecord(aMessage, aOffset);
1073 
1074 exit:
1075     return error;
1076 }
1077 
ReadTxtData(const Message & aMessage,uint16_t & aOffset,uint8_t * aTxtBuffer,uint16_t & aTxtBufferSize) const1078 Error TxtRecord::ReadTxtData(const Message &aMessage,
1079                              uint16_t &     aOffset,
1080                              uint8_t *      aTxtBuffer,
1081                              uint16_t &     aTxtBufferSize) const
1082 {
1083     Error error = kErrorNone;
1084 
1085     VerifyOrExit(GetLength() <= aTxtBufferSize, error = kErrorNoBufs);
1086     SuccessOrExit(error = aMessage.Read(aOffset, aTxtBuffer, GetLength()));
1087     VerifyOrExit(VerifyTxtData(aTxtBuffer, GetLength()), error = kErrorParse);
1088     aTxtBufferSize = GetLength();
1089     aOffset += GetLength();
1090 
1091 exit:
1092     return error;
1093 }
1094 
VerifyTxtData(const uint8_t * aTxtData,uint16_t aTxtLength)1095 bool TxtRecord::VerifyTxtData(const uint8_t *aTxtData, uint16_t aTxtLength)
1096 {
1097     bool    valid          = false;
1098     uint8_t curEntryLength = 0;
1099 
1100     // Per RFC 1035, TXT-DATA MUST have one or more <character-string>s.
1101     VerifyOrExit(aTxtLength > 0);
1102 
1103     for (uint16_t i = 0; i < aTxtLength; ++i)
1104     {
1105         if (curEntryLength == 0)
1106         {
1107             curEntryLength = aTxtData[i];
1108         }
1109         else
1110         {
1111             --curEntryLength;
1112         }
1113     }
1114 
1115     valid = (curEntryLength == 0);
1116 
1117 exit:
1118     return valid;
1119 }
1120 
1121 } // namespace Dns
1122 } // namespace ot
1123