1 /*
2 * Copyright (c) 2020, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 /**
30 * @file
31 * This file implements generating and processing of DNS headers and helper functions/methods.
32 */
33
34 #include "dns_types.hpp"
35
36 #include "common/code_utils.hpp"
37 #include "common/debug.hpp"
38 #include "common/instance.hpp"
39 #include "common/num_utils.hpp"
40 #include "common/random.hpp"
41 #include "common/string.hpp"
42
43 namespace ot {
44 namespace Dns {
45
46 using ot::Encoding::BigEndian::HostSwap16;
47
SetRandomMessageId(void)48 Error Header::SetRandomMessageId(void)
49 {
50 return Random::Crypto::FillBuffer(reinterpret_cast<uint8_t *>(&mMessageId), sizeof(mMessageId));
51 }
52
ResponseCodeToError(Response aResponse)53 Error Header::ResponseCodeToError(Response aResponse)
54 {
55 Error error = kErrorFailed;
56
57 switch (aResponse)
58 {
59 case kResponseSuccess:
60 error = kErrorNone;
61 break;
62
63 case kResponseFormatError: // Server unable to interpret request due to format error.
64 case kResponseBadName: // Bad name.
65 case kResponseBadTruncation: // Bad truncation.
66 case kResponseNotZone: // A name is not in the zone.
67 error = kErrorParse;
68 break;
69
70 case kResponseServerFailure: // Server encountered an internal failure.
71 error = kErrorFailed;
72 break;
73
74 case kResponseNameError: // Name that ought to exist, does not exists.
75 case kResponseRecordNotExists: // Some RRset that ought to exist, does not exist.
76 error = kErrorNotFound;
77 break;
78
79 case kResponseNotImplemented: // Server does not support the query type (OpCode).
80 case kDsoTypeNotImplemented: // DSO TLV type is not implemented.
81 error = kErrorNotImplemented;
82 break;
83
84 case kResponseBadAlg: // Bad algorithm.
85 error = kErrorNotCapable;
86 break;
87
88 case kResponseNameExists: // Some name that ought not to exist, does exist.
89 case kResponseRecordExists: // Some RRset that ought not to exist, does exist.
90 error = kErrorDuplicated;
91 break;
92
93 case kResponseRefused: // Server refused to perform operation for policy or security reasons.
94 case kResponseNotAuth: // Service is not authoritative for zone.
95 error = kErrorSecurity;
96 break;
97
98 default:
99 break;
100 }
101
102 return error;
103 }
104
AppendTo(Message & aMessage) const105 Error Name::AppendTo(Message &aMessage) const
106 {
107 Error error;
108
109 if (IsEmpty())
110 {
111 error = AppendTerminator(aMessage);
112 }
113 else if (IsFromCString())
114 {
115 error = AppendName(GetAsCString(), aMessage);
116 }
117 else
118 {
119 // Name is from a message. Read labels one by one from
120 // `mMessage` and and append each to the `aMessage`.
121
122 LabelIterator iterator(*mMessage, mOffset);
123
124 while (true)
125 {
126 error = iterator.GetNextLabel();
127
128 switch (error)
129 {
130 case kErrorNone:
131 SuccessOrExit(error = iterator.AppendLabel(aMessage));
132 break;
133
134 case kErrorNotFound:
135 // We reached the end of name successfully.
136 error = AppendTerminator(aMessage);
137
138 OT_FALL_THROUGH;
139
140 default:
141 ExitNow();
142 }
143 }
144 }
145
146 exit:
147 return error;
148 }
149
AppendLabel(const char * aLabel,Message & aMessage)150 Error Name::AppendLabel(const char *aLabel, Message &aMessage)
151 {
152 return AppendLabel(aLabel, static_cast<uint8_t>(StringLength(aLabel, kMaxLabelSize)), aMessage);
153 }
154
AppendLabel(const char * aLabel,uint8_t aLength,Message & aMessage)155 Error Name::AppendLabel(const char *aLabel, uint8_t aLength, Message &aMessage)
156 {
157 Error error = kErrorNone;
158
159 VerifyOrExit((0 < aLength) && (aLength <= kMaxLabelLength), error = kErrorInvalidArgs);
160
161 SuccessOrExit(error = aMessage.Append(aLength));
162 error = aMessage.AppendBytes(aLabel, aLength);
163
164 exit:
165 return error;
166 }
167
AppendMultipleLabels(const char * aLabels,Message & aMessage)168 Error Name::AppendMultipleLabels(const char *aLabels, Message &aMessage)
169 {
170 Error error = kErrorNone;
171 uint16_t index = 0;
172 uint16_t labelStartIndex = 0;
173 char ch;
174
175 VerifyOrExit(aLabels != nullptr);
176
177 do
178 {
179 ch = aLabels[index];
180
181 if ((ch == kNullChar) || (ch == kLabelSeparatorChar))
182 {
183 uint8_t labelLength = static_cast<uint8_t>(index - labelStartIndex);
184
185 if (labelLength == 0)
186 {
187 // Empty label (e.g., consecutive dots) is invalid, but we
188 // allow for two cases: (1) where `aLabels` ends with a dot
189 // (`labelLength` is zero but we are at end of `aLabels` string
190 // and `ch` is null char. (2) if `aLabels` is just "." (we
191 // see a dot at index 0, and index 1 is null char).
192
193 error =
194 ((ch == kNullChar) || ((index == 0) && (aLabels[1] == kNullChar))) ? kErrorNone : kErrorInvalidArgs;
195 ExitNow();
196 }
197
198 VerifyOrExit(index + 1 < kMaxEncodedLength, error = kErrorInvalidArgs);
199 SuccessOrExit(error = AppendLabel(&aLabels[labelStartIndex], labelLength, aMessage));
200
201 labelStartIndex = index + 1;
202 }
203
204 index++;
205
206 } while (ch != kNullChar);
207
208 exit:
209 return error;
210 }
211
AppendTerminator(Message & aMessage)212 Error Name::AppendTerminator(Message &aMessage)
213 {
214 uint8_t terminator = 0;
215
216 return aMessage.Append(terminator);
217 }
218
AppendPointerLabel(uint16_t aOffset,Message & aMessage)219 Error Name::AppendPointerLabel(uint16_t aOffset, Message &aMessage)
220 {
221 Error error;
222 uint16_t value;
223
224 #if OPENTHREAD_CONFIG_REFERENCE_DEVICE_ENABLE
225 if (!Instance::IsDnsNameCompressionEnabled())
226 {
227 // If "DNS name compression" mode is disabled, instead of
228 // appending the pointer label, read the name from the message
229 // and append it uncompressed. Note that the `aOffset` parameter
230 // in this method is given relative to the start of DNS header
231 // in `aMessage` (which `aMessage.GetOffset()` specifies).
232
233 error = Name(aMessage, aOffset + aMessage.GetOffset()).AppendTo(aMessage);
234 ExitNow();
235 }
236 #endif
237
238 // A pointer label takes the form of a two byte sequence as a
239 // `uint16_t` value. The first two bits are ones. This allows a
240 // pointer to be distinguished from a text label, since the text
241 // label must begin with two zero bits (note that labels are
242 // restricted to 63 octets or less). The next 14-bits specify
243 // an offset value relative to start of DNS header.
244
245 OT_ASSERT(aOffset < kPointerLabelTypeUint16);
246
247 value = HostSwap16(aOffset | kPointerLabelTypeUint16);
248
249 ExitNow(error = aMessage.Append(value));
250
251 exit:
252 return error;
253 }
254
AppendName(const char * aName,Message & aMessage)255 Error Name::AppendName(const char *aName, Message &aMessage)
256 {
257 Error error;
258
259 SuccessOrExit(error = AppendMultipleLabels(aName, aMessage));
260 error = AppendTerminator(aMessage);
261
262 exit:
263 return error;
264 }
265
ParseName(const Message & aMessage,uint16_t & aOffset)266 Error Name::ParseName(const Message &aMessage, uint16_t &aOffset)
267 {
268 Error error;
269 LabelIterator iterator(aMessage, aOffset);
270
271 while (true)
272 {
273 error = iterator.GetNextLabel();
274
275 switch (error)
276 {
277 case kErrorNone:
278 break;
279
280 case kErrorNotFound:
281 // We reached the end of name successfully.
282 aOffset = iterator.mNameEndOffset;
283 error = kErrorNone;
284
285 OT_FALL_THROUGH;
286
287 default:
288 ExitNow();
289 }
290 }
291
292 exit:
293 return error;
294 }
295
ReadLabel(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t & aLabelLength)296 Error Name::ReadLabel(const Message &aMessage, uint16_t &aOffset, char *aLabelBuffer, uint8_t &aLabelLength)
297 {
298 Error error;
299 LabelIterator iterator(aMessage, aOffset);
300
301 SuccessOrExit(error = iterator.GetNextLabel());
302 SuccessOrExit(error = iterator.ReadLabel(aLabelBuffer, aLabelLength, /* aAllowDotCharInLabel */ true));
303 aOffset = iterator.mNextLabelOffset;
304
305 exit:
306 return error;
307 }
308
ReadName(const Message & aMessage,uint16_t & aOffset,char * aNameBuffer,uint16_t aNameBufferSize)309 Error Name::ReadName(const Message &aMessage, uint16_t &aOffset, char *aNameBuffer, uint16_t aNameBufferSize)
310 {
311 Error error;
312 LabelIterator iterator(aMessage, aOffset);
313 bool firstLabel = true;
314 uint8_t labelLength;
315
316 while (true)
317 {
318 error = iterator.GetNextLabel();
319
320 switch (error)
321 {
322 case kErrorNone:
323
324 if (!firstLabel)
325 {
326 *aNameBuffer++ = kLabelSeparatorChar;
327 aNameBufferSize--;
328
329 // No need to check if we have reached end of the name buffer
330 // here since `iterator.ReadLabel()` would verify it.
331 }
332
333 labelLength = static_cast<uint8_t>(Min(static_cast<uint16_t>(kMaxLabelSize), aNameBufferSize));
334 SuccessOrExit(error = iterator.ReadLabel(aNameBuffer, labelLength, /* aAllowDotCharInLabel */ firstLabel));
335 aNameBuffer += labelLength;
336 aNameBufferSize -= labelLength;
337 firstLabel = false;
338 break;
339
340 case kErrorNotFound:
341 // We reach the end of name successfully. Always add a terminating dot
342 // at the end.
343 *aNameBuffer++ = kLabelSeparatorChar;
344 aNameBufferSize--;
345 VerifyOrExit(aNameBufferSize >= sizeof(uint8_t), error = kErrorNoBufs);
346 *aNameBuffer = kNullChar;
347 aOffset = iterator.mNameEndOffset;
348 error = kErrorNone;
349
350 OT_FALL_THROUGH;
351
352 default:
353 ExitNow();
354 }
355 }
356
357 exit:
358 return error;
359 }
360
CompareLabel(const Message & aMessage,uint16_t & aOffset,const char * aLabel)361 Error Name::CompareLabel(const Message &aMessage, uint16_t &aOffset, const char *aLabel)
362 {
363 Error error;
364 LabelIterator iterator(aMessage, aOffset);
365
366 SuccessOrExit(error = iterator.GetNextLabel());
367 VerifyOrExit(iterator.CompareLabel(aLabel, kIsSingleLabel), error = kErrorNotFound);
368 aOffset = iterator.mNextLabelOffset;
369
370 exit:
371 return error;
372 }
373
CompareName(const Message & aMessage,uint16_t & aOffset,const char * aName)374 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const char *aName)
375 {
376 Error error;
377 LabelIterator iterator(aMessage, aOffset);
378 bool matches = true;
379
380 if (*aName == kLabelSeparatorChar)
381 {
382 aName++;
383 VerifyOrExit(*aName == kNullChar, error = kErrorInvalidArgs);
384 }
385
386 while (true)
387 {
388 error = iterator.GetNextLabel();
389
390 switch (error)
391 {
392 case kErrorNone:
393 if (matches && !iterator.CompareLabel(aName, !kIsSingleLabel))
394 {
395 matches = false;
396 }
397
398 break;
399
400 case kErrorNotFound:
401 // We reached the end of the name in `aMessage`. We check if
402 // all the previous labels matched so far, and we are also
403 // at the end of `aName` string (see null char), then we
404 // return `kErrorNone` indicating a successful comparison
405 // (full match). Otherwise we return `kErrorNotFound` to
406 // indicate failed comparison.
407
408 if (matches && (*aName == kNullChar))
409 {
410 error = kErrorNone;
411 }
412
413 aOffset = iterator.mNameEndOffset;
414
415 OT_FALL_THROUGH;
416
417 default:
418 ExitNow();
419 }
420 }
421
422 exit:
423 return error;
424 }
425
CompareName(const Message & aMessage,uint16_t & aOffset,const Message & aMessage2,uint16_t aOffset2)426 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Message &aMessage2, uint16_t aOffset2)
427 {
428 Error error;
429 LabelIterator iterator(aMessage, aOffset);
430 LabelIterator iterator2(aMessage2, aOffset2);
431 bool matches = true;
432
433 while (true)
434 {
435 error = iterator.GetNextLabel();
436
437 switch (error)
438 {
439 case kErrorNone:
440 // If all the previous labels matched so far, then verify
441 // that we can get the next label on `iterator2` and that it
442 // matches the label from `iterator`.
443 if (matches && (iterator2.GetNextLabel() != kErrorNone || !iterator.CompareLabel(iterator2)))
444 {
445 matches = false;
446 }
447
448 break;
449
450 case kErrorNotFound:
451 // We reached the end of the name in `aMessage`. We check
452 // that `iterator2` is also at its end, and if all previous
453 // labels matched we return `kErrorNone`.
454
455 if (matches && (iterator2.GetNextLabel() == kErrorNotFound))
456 {
457 error = kErrorNone;
458 }
459
460 aOffset = iterator.mNameEndOffset;
461
462 OT_FALL_THROUGH;
463
464 default:
465 ExitNow();
466 }
467 }
468
469 exit:
470 return error;
471 }
472
CompareName(const Message & aMessage,uint16_t & aOffset,const Name & aName)473 Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const Name &aName)
474 {
475 return aName.IsFromCString()
476 ? CompareName(aMessage, aOffset, aName.mString)
477 : (aName.IsFromMessage() ? CompareName(aMessage, aOffset, *aName.mMessage, aName.mOffset)
478 : ParseName(aMessage, aOffset));
479 }
480
// Advances the iterator to the next text label of the name, following any
// compression pointers on the way.
//
// Returns:
//  - kErrorNone     when a text label is found. `mLabelStartOffset` and
//                   `mLabelLength` describe it, and `mNextLabelOffset`
//                   points just past it.
//  - kErrorNotFound when the zero-length label (end of name) is read.
//  - kErrorParse    on an unknown label type, or a pointer whose target is
//                   not strictly before the pointer itself.
//  - any error returned by `mMessage.Read()`.
//
// `mNameEndOffset` is set only once (guarded by `IsEndOffsetSet()`): at
// the first pointer or at the terminating zero label, whichever comes
// first. It therefore marks the end of the name's encoding at the original
// offset, not at a pointer target.
//
// NOTE(review): the `nextLabelOffset < mNextLabelOffset` check only
// requires each pointer to point before the pointer itself. Take a text
// label at Q followed by a pointer at P whose target is Q: Q < P passes,
// so the iterator keeps returning the same label forever. Callers that loop
// until `kErrorNotFound` (e.g. `ParseName()`) would then never terminate.
// Tracking the minimum pointer target seen so far (a new iterator member,
// declared in the header) and requiring each new target to be below it
// would close this - please confirm and fix.
Error Name::LabelIterator::GetNextLabel(void)
{
    Error error;

    while (true)
    {
        uint8_t labelLength;
        uint8_t labelType;

        SuccessOrExit(error = mMessage.Read(mNextLabelOffset, labelLength));

        // The bits selected by `kLabelTypeMask` distinguish text labels
        // from pointer labels.
        labelType = labelLength & kLabelTypeMask;

        if (labelType == kTextLabelType)
        {
            if (labelLength == 0)
            {
                // Zero label length indicates end of a name.

                if (!IsEndOffsetSet())
                {
                    mNameEndOffset = mNextLabelOffset + sizeof(uint8_t);
                }

                ExitNow(error = kErrorNotFound);
            }

            // Record the label's bytes (after its length byte) and advance
            // past them. The label bytes themselves are not read here.
            mLabelStartOffset = mNextLabelOffset + sizeof(uint8_t);
            mLabelLength      = labelLength;
            mNextLabelOffset  = mLabelStartOffset + labelLength;
            ExitNow();
        }
        else if (labelType == kPointerLabelType)
        {
            // A pointer label takes the form of a two byte sequence as a
            // `uint16_t` value. The first two bits are ones. The next 14 bits
            // specify an offset value from the start of the DNS header.

            uint16_t pointerValue;
            uint16_t nextLabelOffset;

            SuccessOrExit(error = mMessage.Read(mNextLabelOffset, pointerValue));

            // The name's encoding at the original offset ends right after
            // the first pointer met.
            if (!IsEndOffsetSet())
            {
                mNameEndOffset = mNextLabelOffset + sizeof(uint16_t);
            }

            // `mMessage.GetOffset()` must point to the start of the
            // DNS header.
            nextLabelOffset = mMessage.GetOffset() + (HostSwap16(pointerValue) & kPointerLabelOffsetMask);

            // Only backward pointers are accepted (see NOTE above on the
            // limits of this check).
            VerifyOrExit(nextLabelOffset < mNextLabelOffset, error = kErrorParse);
            mNextLabelOffset = nextLabelOffset;

            // Go back through the `while(true)` loop to get the next label.
        }
        else
        {
            ExitNow(error = kErrorParse);
        }
    }

exit:
    return error;
}
546
ReadLabel(char * aLabelBuffer,uint8_t & aLabelLength,bool aAllowDotCharInLabel) const547 Error Name::LabelIterator::ReadLabel(char *aLabelBuffer, uint8_t &aLabelLength, bool aAllowDotCharInLabel) const
548 {
549 Error error;
550
551 VerifyOrExit(mLabelLength < aLabelLength, error = kErrorNoBufs);
552
553 SuccessOrExit(error = mMessage.Read(mLabelStartOffset, aLabelBuffer, mLabelLength));
554 aLabelBuffer[mLabelLength] = kNullChar;
555 aLabelLength = mLabelLength;
556
557 if (!aAllowDotCharInLabel)
558 {
559 VerifyOrExit(StringFind(aLabelBuffer, kLabelSeparatorChar) == nullptr, error = kErrorParse);
560 }
561
562 exit:
563 return error;
564 }
565
CaseInsensitiveMatch(uint8_t aFirst,uint8_t aSecond)566 bool Name::LabelIterator::CaseInsensitiveMatch(uint8_t aFirst, uint8_t aSecond)
567 {
568 return ToLowercase(static_cast<char>(aFirst)) == ToLowercase(static_cast<char>(aSecond));
569 }
570
CompareLabel(const char * & aName,bool aIsSingleLabel) const571 bool Name::LabelIterator::CompareLabel(const char *&aName, bool aIsSingleLabel) const
572 {
573 // This method compares the current label in the iterator with the
574 // `aName` string. `aIsSingleLabel` indicates whether `aName` is a
575 // single label, or a sequence of labels separated by dot '.' char.
576 // If the label matches `aName`, then `aName` pointer is moved
577 // forward to the start of the next label (skipping over the `.`
578 // char). This method returns `true` when the labels match, `false`
579 // otherwise.
580
581 bool matches = false;
582
583 VerifyOrExit(StringLength(aName, mLabelLength) == mLabelLength);
584 matches = mMessage.CompareBytes(mLabelStartOffset, aName, mLabelLength, CaseInsensitiveMatch);
585
586 VerifyOrExit(matches);
587
588 aName += mLabelLength;
589
590 // If `aName` is a single label, we should be also at the end of the
591 // `aName` string. Otherwise, we should see either null or dot '.'
592 // character (in case `aName` contains multiple labels).
593
594 matches = (*aName == kNullChar);
595
596 if (!aIsSingleLabel && (*aName == kLabelSeparatorChar))
597 {
598 matches = true;
599 aName++;
600 }
601
602 exit:
603 return matches;
604 }
605
CompareLabel(const LabelIterator & aOtherIterator) const606 bool Name::LabelIterator::CompareLabel(const LabelIterator &aOtherIterator) const
607 {
608 // This method compares the current label in the iterator with the
609 // label from another iterator.
610
611 return (mLabelLength == aOtherIterator.mLabelLength) &&
612 mMessage.CompareBytes(mLabelStartOffset, aOtherIterator.mMessage, aOtherIterator.mLabelStartOffset,
613 mLabelLength, CaseInsensitiveMatch);
614 }
615
AppendLabel(Message & aMessage) const616 Error Name::LabelIterator::AppendLabel(Message &aMessage) const
617 {
618 // This method reads and appends the current label in the iterator
619 // to `aMessage`.
620
621 Error error;
622
623 VerifyOrExit((0 < mLabelLength) && (mLabelLength <= kMaxLabelLength), error = kErrorInvalidArgs);
624 SuccessOrExit(error = aMessage.Append(mLabelLength));
625 error = aMessage.AppendBytesFromMessage(mMessage, mLabelStartOffset, mLabelLength);
626
627 exit:
628 return error;
629 }
630
ExtractLabels(const char * aName,const char * aSuffixName,char * aLabels,uint16_t aLabelsSize)631 Error Name::ExtractLabels(const char *aName, const char *aSuffixName, char *aLabels, uint16_t aLabelsSize)
632 {
633 Error error = kErrorParse;
634 uint16_t nameLength = StringLength(aName, kMaxNameSize);
635 uint16_t suffixLength = StringLength(aSuffixName, kMaxNameSize);
636 const char *suffixStart;
637
638 VerifyOrExit(nameLength < kMaxNameSize);
639 VerifyOrExit(suffixLength < kMaxNameSize);
640
641 VerifyOrExit(nameLength > suffixLength);
642
643 suffixStart = aName + nameLength - suffixLength;
644 VerifyOrExit(StringMatch(suffixStart, aSuffixName, kStringCaseInsensitiveMatch));
645 suffixStart--;
646 VerifyOrExit(*suffixStart == kLabelSeparatorChar);
647
648 // Determine the labels length to copy
649 nameLength -= (suffixLength + 1);
650 VerifyOrExit(nameLength < aLabelsSize, error = kErrorNoBufs);
651
652 memcpy(aLabels, aName, nameLength);
653 aLabels[nameLength] = kNullChar;
654 error = kErrorNone;
655
656 exit:
657 return error;
658 }
659
// Returns true if `aName` equals `aDomain` or is a sub-domain of it
// (compared case-insensitively, on label boundaries). A trailing '.' on
// either string is optional and ignored for the comparison. Both strings
// are measured up to `kMaxNameLength` chars.
//
// NOTE(review): when `aDomain` is "." (root), `domainLength` becomes 0, so
// any non-empty `aName` fails the `aName[-1]` separator check. Only "" or
// "." is then reported as a sub-domain of the root. Confirm this is the
// intended behavior.
bool Name::IsSubDomainOf(const char *aName, const char *aDomain)
{
    bool     match             = false;
    bool     nameEndsWithDot   = false;
    bool     domainEndsWithDot = false;
    uint16_t nameLength        = StringLength(aName, kMaxNameLength);
    uint16_t domainLength      = StringLength(aDomain, kMaxNameLength);

    // Strip (but remember) an optional trailing dot on each string.

    if (nameLength > 0 && aName[nameLength - 1] == kLabelSeparatorChar)
    {
        nameEndsWithDot = true;
        --nameLength;
    }

    if (domainLength > 0 && aDomain[domainLength - 1] == kLabelSeparatorChar)
    {
        domainEndsWithDot = true;
        --domainLength;
    }

    VerifyOrExit(nameLength >= domainLength);

    // Move `aName` to the candidate domain part at its tail.
    aName += nameLength - domainLength;

    // The candidate must begin on a label boundary, i.e. be preceded by a
    // dot (so "xlocal" is not treated as a sub-domain of "local").
    if (nameLength > domainLength)
    {
        VerifyOrExit(aName[-1] == kLabelSeparatorChar);
    }

    // This method allows either `aName` or `aDomain` to include or
    // exclude the last `.` character. If both include it or if both
    // do not, we do a full comparison using `StringMatch()`.
    // Otherwise (i.e., when one includes and the other one does not)
    // we use `StringStartsWith()` to allow the extra `.` character.

    if (nameEndsWithDot == domainEndsWithDot)
    {
        match = StringMatch(aName, aDomain, kStringCaseInsensitiveMatch);
    }
    else if (nameEndsWithDot)
    {
        // `aName` ends with dot, but `aDomain` does not.
        match = StringStartsWith(aName, aDomain, kStringCaseInsensitiveMatch);
    }
    else
    {
        // `aDomain` ends with dot, but `aName` does not.
        match = StringStartsWith(aDomain, aName, kStringCaseInsensitiveMatch);
    }

exit:
    return match;
}
713
IsSameDomain(const char * aDomain1,const char * aDomain2)714 bool Name::IsSameDomain(const char *aDomain1, const char *aDomain2)
715 {
716 return IsSubDomainOf(aDomain1, aDomain2) && IsSubDomainOf(aDomain2, aDomain1);
717 }
718
ParseRecords(const Message & aMessage,uint16_t & aOffset,uint16_t aNumRecords)719 Error ResourceRecord::ParseRecords(const Message &aMessage, uint16_t &aOffset, uint16_t aNumRecords)
720 {
721 Error error = kErrorNone;
722
723 while (aNumRecords > 0)
724 {
725 ResourceRecord record;
726
727 SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
728 SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
729 aOffset += static_cast<uint16_t>(record.GetSize());
730 aNumRecords--;
731 }
732
733 exit:
734 return error;
735 }
736
FindRecord(const Message & aMessage,uint16_t & aOffset,uint16_t & aNumRecords,const Name & aName)737 Error ResourceRecord::FindRecord(const Message &aMessage, uint16_t &aOffset, uint16_t &aNumRecords, const Name &aName)
738 {
739 Error error;
740
741 while (aNumRecords > 0)
742 {
743 bool matches = true;
744 ResourceRecord record;
745
746 error = Name::CompareName(aMessage, aOffset, aName);
747
748 switch (error)
749 {
750 case kErrorNone:
751 break;
752 case kErrorNotFound:
753 matches = false;
754 break;
755 default:
756 ExitNow();
757 }
758
759 SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
760 aNumRecords--;
761 VerifyOrExit(!matches);
762 aOffset += static_cast<uint16_t>(record.GetSize());
763 }
764
765 error = kErrorNotFound;
766
767 exit:
768 return error;
769 }
770
Error ResourceRecord::FindRecord(const Message &aMessage,
                                 uint16_t &aOffset,
                                 uint16_t aNumRecords,
                                 uint16_t aIndex,
                                 const Name &aName,
                                 uint16_t aType,
                                 ResourceRecord &aRecord,
                                 uint16_t aMinRecordSize)
{
    // This static method searches in `aMessage` starting from `aOffset`
    // up to maximum of `aNumRecords`, for the `(aIndex+1)`th
    // occurrence of a resource record of type `aType` (`kTypeAny`
    // matches any type) with record name matching `aName`. It also
    // verifies that the record size is at least `aMinRecordSize`. If
    // found, `aMinRecordSize` bytes from the record are read and copied
    // into `aRecord`. In this case `aOffset` is updated to point just
    // past the last record byte read from the message (so that the
    // caller can read any remaining fields in the record data).
    //
    // Returns `kErrorNotFound` when no such record exists within
    // `aNumRecords`; parse errors are propagated. `aOffset` is only
    // modified when a record is found.

    Error    error;
    uint16_t offset = aOffset;
    uint16_t recordOffset;

    while (aNumRecords > 0)
    {
        // Move `offset` to the fields of the next record whose name matches
        // `aName`. This also decrements `aNumRecords` by the number of
        // records it consumed.
        SuccessOrExit(error = FindRecord(aMessage, offset, aNumRecords, aName));

        // Save the offset to start of `ResourceRecord` fields.
        recordOffset = offset;

        error = ReadRecord(aMessage, offset, aType, aRecord, aMinRecordSize);

        if (error == kErrorNotFound)
        {
            // `ReadRecord()` already updates the `offset` to skip
            // over a non-matching record.
            continue;
        }

        SuccessOrExit(error);

        if (aIndex == 0)
        {
            aOffset = offset;
            ExitNow();
        }

        aIndex--;

        // Skip over the record.
        offset = static_cast<uint16_t>(recordOffset + aRecord.GetSize());
    }

    error = kErrorNotFound;

exit:
    return error;
}
829
ReadRecord(const Message & aMessage,uint16_t & aOffset,uint16_t aType,ResourceRecord & aRecord,uint16_t aMinRecordSize)830 Error ResourceRecord::ReadRecord(const Message &aMessage,
831 uint16_t &aOffset,
832 uint16_t aType,
833 ResourceRecord &aRecord,
834 uint16_t aMinRecordSize)
835 {
836 // This static method tries to read a matching resource record of a
837 // given type and a minimum record size from a message. The `aType`
838 // value of `kTypeAny` matches any type. If the record in the
839 // message does not match, it skips over the record. Please see
840 // `ReadRecord<RecordType>()` for more details.
841
842 Error error;
843 ResourceRecord record;
844
845 SuccessOrExit(error = record.ReadFrom(aMessage, aOffset));
846
847 if (((aType == kTypeAny) || (record.GetType() == aType)) && (record.GetSize() >= aMinRecordSize))
848 {
849 IgnoreError(aMessage.Read(aOffset, &aRecord, aMinRecordSize));
850 aOffset += aMinRecordSize;
851 }
852 else
853 {
854 // Skip over the entire record.
855 aOffset += static_cast<uint16_t>(record.GetSize());
856 error = kErrorNotFound;
857 }
858
859 exit:
860 return error;
861 }
862
ReadName(const Message & aMessage,uint16_t & aOffset,uint16_t aStartOffset,char * aNameBuffer,uint16_t aNameBufferSize,bool aSkipRecord) const863 Error ResourceRecord::ReadName(const Message &aMessage,
864 uint16_t &aOffset,
865 uint16_t aStartOffset,
866 char *aNameBuffer,
867 uint16_t aNameBufferSize,
868 bool aSkipRecord) const
869 {
870 // This protected method parses and reads a name field in a record
871 // from a message. It is intended only for sub-classes of
872 // `ResourceRecord`.
873 //
874 // On input `aOffset` gives the offset in `aMessage` to the start of
875 // name field. `aStartOffset` gives the offset to the start of the
876 // `ResourceRecord`. `aSkipRecord` indicates whether to skip over
877 // the entire resource record or just the read name. On exit, when
878 // successfully read, `aOffset` is updated to either point after the
879 // end of record or after the the name field.
880 //
881 // When read successfully, this method returns `kErrorNone`. On a
882 // parse error (invalid format) returns `kErrorParse`. If the
883 // name does not fit in the given name buffer it returns
884 // `kErrorNoBufs`
885
886 Error error = kErrorNone;
887
888 SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
889 VerifyOrExit(aOffset <= aStartOffset + GetSize(), error = kErrorParse);
890
891 VerifyOrExit(aSkipRecord);
892 aOffset = aStartOffset;
893 error = SkipRecord(aMessage, aOffset);
894
895 exit:
896 return error;
897 }
898
SkipRecord(const Message & aMessage,uint16_t & aOffset) const899 Error ResourceRecord::SkipRecord(const Message &aMessage, uint16_t &aOffset) const
900 {
901 // This protected method parses and skips over a resource record
902 // in a message.
903 //
904 // On input `aOffset` gives the offset in `aMessage` to the start of
905 // the `ResourceRecord`. On exit, when successfully parsed, `aOffset`
906 // is updated to point to byte after the entire record.
907
908 Error error;
909
910 SuccessOrExit(error = CheckRecord(aMessage, aOffset));
911 aOffset += static_cast<uint16_t>(GetSize());
912
913 exit:
914 return error;
915 }
916
CheckRecord(const Message & aMessage,uint16_t aOffset) const917 Error ResourceRecord::CheckRecord(const Message &aMessage, uint16_t aOffset) const
918 {
919 // This method checks that the entire record (including record data)
920 // is present in `aMessage` at `aOffset` (pointing to the start of
921 // the `ResourceRecord` in `aMessage`).
922
923 return (aOffset + GetSize() <= aMessage.GetLength()) ? kErrorNone : kErrorParse;
924 }
925
ReadFrom(const Message & aMessage,uint16_t aOffset)926 Error ResourceRecord::ReadFrom(const Message &aMessage, uint16_t aOffset)
927 {
928 // This method reads the `ResourceRecord` from `aMessage` at
929 // `aOffset`. It verifies that the entire record (including record
930 // data) is present in the message.
931
932 Error error;
933
934 SuccessOrExit(error = aMessage.Read(aOffset, *this));
935 error = CheckRecord(aMessage, aOffset);
936
937 exit:
938 return error;
939 }
940
// Initializes the iterator over the encoded TXT data `aTxtData` of
// `aTxtDataLength` bytes. Iteration starts at the beginning (position 0).
void TxtEntry::Iterator::Init(const uint8_t *aTxtData, uint16_t aTxtDataLength)
{
    SetTxtData(aTxtData);
    SetTxtDataLength(aTxtDataLength);
    SetTxtDataPosition(0);
}
947
// Parses the next entry from the TXT data and fills `aEntry`.
//
// The TXT data is a sequence of strings, each preceded by a one-byte
// length. For each string:
//  - "key=value": `aEntry.mKey` points to the key, copied and
//    null-terminated, in this iterator's key buffer. `mValue`/`mValueLength`
//    point into the TXT data just after the `=` (the value may be empty).
//  - "key" (no `=`): boolean attribute; `mValue` is `nullptr` and
//    `mValueLength` is 0.
//  - key longer than `kMaxKeyLength`: `mKey` is `nullptr` and
//    `mValue`/`mValueLength` cover the whole string (excluding its length
//    byte).
// Empty strings and strings starting with `=` are skipped.
//
// Returns `kErrorNone` when an entry is produced and `kErrorNotFound` when
// the data is exhausted. Returns `kErrorParse` if no TXT data was set, or
// if a string's length runs past the end of the data.
//
// `aEntry.mKey` points into this iterator's key buffer, which the next call
// overwrites.
Error TxtEntry::Iterator::GetNextEntry(TxtEntry &aEntry)
{
    Error       error = kErrorNone;
    uint8_t     length;
    uint8_t     index;
    const char *cur;
    char       *keyBuffer = GetKeyBuffer();

    static_assert(sizeof(mChar) == TxtEntry::kMaxKeyLength + 1, "KeyBuffer cannot fit the max key length");

    VerifyOrExit(GetTxtData() != nullptr, error = kErrorParse);

    aEntry.mKey = keyBuffer;

    while ((cur = GetTxtData() + GetTxtDataPosition()) < GetTxtDataEnd())
    {
        length = static_cast<uint8_t>(*cur);

        cur++;
        VerifyOrExit(cur + length <= GetTxtDataEnd(), error = kErrorParse);

        // Consume the length byte and the string now, so the next call
        // continues after this string even when it is skipped below.
        IncreaseTxtDataPosition(sizeof(uint8_t) + length);

        // Silently skip over an empty string or if the string starts with
        // a `=` character (i.e., missing key) - RFC 6763 - section 6.4.

        if ((length == 0) || (cur[0] == kKeyValueSeparator))
        {
            continue;
        }

        for (index = 0; index < length; index++)
        {
            // The `=` check comes before the max-length check. A key of
            // exactly `kMaxKeyLength` chars therefore writes its null
            // terminator at `keyBuffer[kMaxKeyLength]`, the last slot of
            // the buffer (see static_assert above).
            if (cur[index] == kKeyValueSeparator)
            {
                keyBuffer[index++] = kNullChar; // Increment index to skip over `=`.
                aEntry.mValue       = reinterpret_cast<const uint8_t *>(&cur[index]);
                aEntry.mValueLength = length - index;
                ExitNow();
            }

            if (index >= kMaxKeyLength)
            {
                // The key is larger than recommended max key length.
                // In this case, we return the full encoded string in
                // `mValue` and `mValueLength` and set `mKey` to
                // `nullptr`.

                aEntry.mKey         = nullptr;
                aEntry.mValue       = reinterpret_cast<const uint8_t *>(cur);
                aEntry.mValueLength = length;
                ExitNow();
            }

            keyBuffer[index] = cur[index];
        }

        // If we reach the end of the string without finding `=` then
        // it is a boolean key attribute (encoded as "key").

        keyBuffer[index]    = kNullChar;
        aEntry.mValue       = nullptr;
        aEntry.mValueLength = 0;
        ExitNow();
    }

    error = kErrorNotFound;

exit:
    return error;
}
1018
AppendTo(Message & aMessage) const1019 Error TxtEntry::AppendTo(Message &aMessage) const
1020 {
1021 Appender appender(aMessage);
1022
1023 return AppendTo(appender);
1024 }
1025
AppendTo(Appender & aAppender) const1026 Error TxtEntry::AppendTo(Appender &aAppender) const
1027 {
1028 Error error = kErrorNone;
1029 uint16_t keyLength;
1030 char separator = kKeyValueSeparator;
1031
1032 if (mKey == nullptr)
1033 {
1034 VerifyOrExit((mValue != nullptr) && (mValueLength != 0));
1035 error = aAppender.AppendBytes(mValue, mValueLength);
1036 ExitNow();
1037 }
1038
1039 keyLength = StringLength(mKey, static_cast<uint16_t>(kMaxKeyValueEncodedSize) + 1);
1040
1041 VerifyOrExit(kMinKeyLength <= keyLength, error = kErrorInvalidArgs);
1042
1043 if (mValue == nullptr)
1044 {
1045 // Treat as a boolean attribute and encoded as "key" (with no `=`).
1046 SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength)));
1047 error = aAppender.AppendBytes(mKey, keyLength);
1048 ExitNow();
1049 }
1050
1051 // Treat as key/value and encode as "key=value", value may be empty.
1052
1053 VerifyOrExit(mValueLength + keyLength + sizeof(char) <= kMaxKeyValueEncodedSize, error = kErrorInvalidArgs);
1054
1055 SuccessOrExit(error = aAppender.Append<uint8_t>(static_cast<uint8_t>(keyLength + mValueLength + sizeof(char))));
1056 SuccessOrExit(error = aAppender.AppendBytes(mKey, keyLength));
1057 SuccessOrExit(error = aAppender.Append(separator));
1058 error = aAppender.AppendBytes(mValue, mValueLength);
1059
1060 exit:
1061 return error;
1062 }
1063
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,Message & aMessage)1064 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, Message &aMessage)
1065 {
1066 Appender appender(aMessage);
1067
1068 return AppendEntries(aEntries, aNumEntries, appender);
1069 }
1070
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,MutableData<kWithUint16Length> & aData)1071 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, MutableData<kWithUint16Length> &aData)
1072 {
1073 Error error;
1074 Appender appender(aData.GetBytes(), aData.GetLength());
1075
1076 SuccessOrExit(error = AppendEntries(aEntries, aNumEntries, appender));
1077 appender.GetAsData(aData);
1078
1079 exit:
1080 return error;
1081 }
1082
AppendEntries(const TxtEntry * aEntries,uint16_t aNumEntries,Appender & aAppender)1083 Error TxtEntry::AppendEntries(const TxtEntry *aEntries, uint16_t aNumEntries, Appender &aAppender)
1084 {
1085 Error error = kErrorNone;
1086
1087 for (uint16_t index = 0; index < aNumEntries; index++)
1088 {
1089 SuccessOrExit(error = aEntries[index].AppendTo(aAppender));
1090 }
1091
1092 if (aAppender.GetAppendedLength() == 0)
1093 {
1094 error = aAppender.Append<uint8_t>(0);
1095 }
1096
1097 exit:
1098 return error;
1099 }
1100
IsValid(void) const1101 bool AaaaRecord::IsValid(void) const
1102 {
1103 return GetType() == Dns::ResourceRecord::kTypeAaaa && GetSize() == sizeof(*this);
1104 }
1105
IsValid(void) const1106 bool KeyRecord::IsValid(void) const { return GetType() == Dns::ResourceRecord::kTypeKey; }
1107
1108 #if OPENTHREAD_CONFIG_SRP_SERVER_ENABLE
Init(void)1109 void Ecdsa256KeyRecord::Init(void)
1110 {
1111 KeyRecord::Init();
1112 SetAlgorithm(kAlgorithmEcdsaP256Sha256);
1113 }
1114
IsValid(void) const1115 bool Ecdsa256KeyRecord::IsValid(void) const
1116 {
1117 return KeyRecord::IsValid() && GetLength() == sizeof(*this) - sizeof(ResourceRecord) &&
1118 GetAlgorithm() == kAlgorithmEcdsaP256Sha256;
1119 }
1120 #endif
1121
IsValid(void) const1122 bool SigRecord::IsValid(void) const
1123 {
1124 return GetType() == Dns::ResourceRecord::kTypeSig && GetLength() >= sizeof(*this) - sizeof(ResourceRecord);
1125 }
1126
InitAsShortVariant(uint32_t aLeaseInterval)1127 void LeaseOption::InitAsShortVariant(uint32_t aLeaseInterval)
1128 {
1129 SetOptionCode(kUpdateLease);
1130 SetOptionLength(kShortLength);
1131 SetLeaseInterval(aLeaseInterval);
1132 }
1133
InitAsLongVariant(uint32_t aLeaseInterval,uint32_t aKeyLeaseInterval)1134 void LeaseOption::InitAsLongVariant(uint32_t aLeaseInterval, uint32_t aKeyLeaseInterval)
1135 {
1136 SetOptionCode(kUpdateLease);
1137 SetOptionLength(kLongLength);
1138 SetLeaseInterval(aLeaseInterval);
1139 SetKeyLeaseInterval(aKeyLeaseInterval);
1140 }
1141
IsValid(void) const1142 bool LeaseOption::IsValid(void) const
1143 {
1144 bool isValid = false;
1145
1146 VerifyOrExit((GetOptionLength() == kShortLength) || (GetOptionLength() >= kLongLength));
1147 isValid = (GetLeaseInterval() <= GetKeyLeaseInterval());
1148
1149 exit:
1150 return isValid;
1151 }
1152
ReadFrom(const Message & aMessage,uint16_t aOffset,uint16_t aLength)1153 Error LeaseOption::ReadFrom(const Message &aMessage, uint16_t aOffset, uint16_t aLength)
1154 {
1155 Error error = kErrorNone;
1156 uint16_t endOffset;
1157
1158 VerifyOrExit(static_cast<uint32_t>(aOffset) + aLength <= aMessage.GetLength(), error = kErrorParse);
1159
1160 endOffset = aOffset + aLength;
1161
1162 while (aOffset < endOffset)
1163 {
1164 uint16_t size;
1165
1166 SuccessOrExit(error = aMessage.Read(aOffset, this, sizeof(Option)));
1167
1168 VerifyOrExit(aOffset + GetSize() <= endOffset, error = kErrorParse);
1169
1170 size = static_cast<uint16_t>(GetSize());
1171
1172 if (GetOptionCode() == kUpdateLease)
1173 {
1174 VerifyOrExit(GetOptionLength() >= kShortLength, error = kErrorParse);
1175
1176 IgnoreError(aMessage.Read(aOffset, this, Min(size, static_cast<uint16_t>(sizeof(LeaseOption)))));
1177 VerifyOrExit(IsValid(), error = kErrorParse);
1178
1179 ExitNow();
1180 }
1181
1182 aOffset += size;
1183 }
1184
1185 error = kErrorNotFound;
1186
1187 exit:
1188 return error;
1189 }
1190
ReadPtrName(const Message & aMessage,uint16_t & aOffset,char * aLabelBuffer,uint8_t aLabelBufferSize,char * aNameBuffer,uint16_t aNameBufferSize) const1191 Error PtrRecord::ReadPtrName(const Message &aMessage,
1192 uint16_t &aOffset,
1193 char *aLabelBuffer,
1194 uint8_t aLabelBufferSize,
1195 char *aNameBuffer,
1196 uint16_t aNameBufferSize) const
1197 {
1198 Error error = kErrorNone;
1199 uint16_t startOffset = aOffset - sizeof(PtrRecord); // start of `PtrRecord`.
1200
1201 // Verify that the name is within the record data length.
1202 SuccessOrExit(error = Name::ParseName(aMessage, aOffset));
1203 VerifyOrExit(aOffset <= startOffset + GetSize(), error = kErrorParse);
1204
1205 aOffset = startOffset + sizeof(PtrRecord);
1206 SuccessOrExit(error = Name::ReadLabel(aMessage, aOffset, aLabelBuffer, aLabelBufferSize));
1207
1208 if (aNameBuffer != nullptr)
1209 {
1210 SuccessOrExit(error = Name::ReadName(aMessage, aOffset, aNameBuffer, aNameBufferSize));
1211 }
1212
1213 aOffset = startOffset;
1214 error = SkipRecord(aMessage, aOffset);
1215
1216 exit:
1217 return error;
1218 }
1219
ReadTxtData(const Message & aMessage,uint16_t & aOffset,uint8_t * aTxtBuffer,uint16_t & aTxtBufferSize) const1220 Error TxtRecord::ReadTxtData(const Message &aMessage,
1221 uint16_t &aOffset,
1222 uint8_t *aTxtBuffer,
1223 uint16_t &aTxtBufferSize) const
1224 {
1225 Error error = kErrorNone;
1226
1227 SuccessOrExit(error = aMessage.Read(aOffset, aTxtBuffer, Min(GetLength(), aTxtBufferSize)));
1228 aOffset += GetLength();
1229
1230 VerifyOrExit(GetLength() <= aTxtBufferSize, error = kErrorNoBufs);
1231 aTxtBufferSize = GetLength();
1232 VerifyOrExit(VerifyTxtData(aTxtBuffer, aTxtBufferSize, /* aAllowEmpty */ true), error = kErrorParse);
1233
1234 exit:
1235 return error;
1236 }
1237
VerifyTxtData(const uint8_t * aTxtData,uint16_t aTxtLength,bool aAllowEmpty)1238 bool TxtRecord::VerifyTxtData(const uint8_t *aTxtData, uint16_t aTxtLength, bool aAllowEmpty)
1239 {
1240 bool valid = false;
1241 uint8_t curEntryLength = 0;
1242
1243 // Per RFC 1035, TXT-DATA MUST have one or more <character-string>s.
1244 VerifyOrExit(aAllowEmpty || aTxtLength > 0);
1245
1246 for (uint16_t i = 0; i < aTxtLength; ++i)
1247 {
1248 if (curEntryLength == 0)
1249 {
1250 curEntryLength = aTxtData[i];
1251 }
1252 else
1253 {
1254 --curEntryLength;
1255 }
1256 }
1257
1258 valid = (curEntryLength == 0);
1259
1260 exit:
1261 return valid;
1262 }
1263
1264 } // namespace Dns
1265 } // namespace ot
1266