1 /*
2 * Copyright (c) 2016, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 /**
30 * @file
 * This file implements common methods for manipulating generic (basic and extended) TLVs.
32 */
33
34 #include "tlvs.hpp"
35
36 #include "common/code_utils.hpp"
37 #include "common/debug.hpp"
38 #include "common/message.hpp"
39
40 namespace ot {
41
GetSize(void) const42 uint32_t Tlv::GetSize(void) const
43 {
44 return IsExtended() ? sizeof(ExtendedTlv) + As<ExtendedTlv>(this)->GetLength() : sizeof(Tlv) + GetLength();
45 }
46
GetValue(void)47 uint8_t *Tlv::GetValue(void)
48 {
49 return reinterpret_cast<uint8_t *>(this) + (IsExtended() ? sizeof(ExtendedTlv) : sizeof(Tlv));
50 }
51
GetValue(void) const52 const uint8_t *Tlv::GetValue(void) const
53 {
54 return reinterpret_cast<const uint8_t *>(this) + (IsExtended() ? sizeof(ExtendedTlv) : sizeof(Tlv));
55 }
56
AppendTo(Message & aMessage) const57 Error Tlv::AppendTo(Message &aMessage) const { return aMessage.AppendBytes(this, static_cast<uint16_t>(GetSize())); }
58
FindTlv(const Message & aMessage,uint8_t aType,uint16_t aMaxSize,Tlv & aTlv)59 Error Tlv::FindTlv(const Message &aMessage, uint8_t aType, uint16_t aMaxSize, Tlv &aTlv)
60 {
61 Error error;
62 ParsedInfo info;
63
64 SuccessOrExit(error = info.FindIn(aMessage, aType));
65
66 if (aMaxSize > info.mSize)
67 {
68 aMaxSize = info.mSize;
69 }
70
71 aMessage.ReadBytes(info.mOffset, &aTlv, aMaxSize);
72
73 exit:
74 return error;
75 }
76
FindTlvOffset(const Message & aMessage,uint8_t aType,uint16_t & aOffset)77 Error Tlv::FindTlvOffset(const Message &aMessage, uint8_t aType, uint16_t &aOffset)
78 {
79 Error error;
80 ParsedInfo info;
81
82 SuccessOrExit(error = info.FindIn(aMessage, aType));
83 aOffset = info.mOffset;
84
85 exit:
86 return error;
87 }
88
FindTlvValueOffset(const Message & aMessage,uint8_t aType,uint16_t & aValueOffset,uint16_t & aLength)89 Error Tlv::FindTlvValueOffset(const Message &aMessage, uint8_t aType, uint16_t &aValueOffset, uint16_t &aLength)
90 {
91 Error error;
92 ParsedInfo info;
93
94 SuccessOrExit(error = info.FindIn(aMessage, aType));
95
96 aValueOffset = info.mValueOffset;
97 aLength = info.mLength;
98
99 exit:
100 return error;
101 }
102
ParseFrom(const Message & aMessage,uint16_t aOffset)103 Error Tlv::ParsedInfo::ParseFrom(const Message &aMessage, uint16_t aOffset)
104 {
105 // This method reads and parses the TLV info from `aMessage` at
106 // `aOffset`. This can be used independent of whether the TLV is
107 // extended or not. It validates that the entire TLV can be read
108 // from `aMessage`. Returns `kErrorNone` when successfully parsed,
109 // otherwise `kErrorParse`.
110
111 Error error;
112 Tlv tlv;
113 ExtendedTlv extTlv;
114 uint16_t headerSize;
115
116 SuccessOrExit(error = aMessage.Read(aOffset, tlv));
117
118 if (!tlv.IsExtended())
119 {
120 mType = tlv.GetType();
121 mLength = tlv.GetLength();
122 headerSize = sizeof(Tlv);
123 }
124 else
125 {
126 SuccessOrExit(error = aMessage.Read(aOffset, extTlv));
127
128 mType = extTlv.GetType();
129 mLength = extTlv.GetLength();
130 headerSize = sizeof(ExtendedTlv);
131 }
132
133 // We know that we could successfully read `tlv` or `extTlv`
134 // (`headerSize` bytes) from the message, so the calculation of the
135 // remaining length as `aMessage.GetLength() - aOffset - headerSize`
136 // cannot underflow.
137
138 VerifyOrExit(mLength <= aMessage.GetLength() - aOffset - headerSize, error = kErrorParse);
139
140 // Now that we know the entire TLV is contained within the
141 // `aMessage`, we can safely calculate `mValueOffset` and `mSize`
142 // as `uint16_t` and know that there will be no overflow.
143
144 mType = tlv.GetType();
145 mOffset = aOffset;
146 mValueOffset = aOffset + headerSize;
147 mSize = mLength + headerSize;
148
149 exit:
150 return error;
151 }
152
FindIn(const Message & aMessage,uint8_t aType)153 Error Tlv::ParsedInfo::FindIn(const Message &aMessage, uint8_t aType)
154 {
155 // This method searches within `aMessage` starting from
156 // `aMessage.GetOffset()` for a TLV type `aType` and parsed its
157 // info and validates that the entire TLV can be read from
158 // `aMessage`. Returns `kErrorNone` when found, otherwise
159 // `kErrorNotFound`.
160
161 Error error = kErrorNotFound;
162 uint16_t offset = aMessage.GetOffset();
163
164 while (true)
165 {
166 SuccessOrExit(ParseFrom(aMessage, offset));
167
168 if (mType == aType)
169 {
170 error = kErrorNone;
171 ExitNow();
172 }
173
174 // `ParseFrom()` already validated that `offset + mSize` is
175 // less than `aMessage.GetLength()` and therefore we can not
176 // have an overflow here.
177
178 offset += mSize;
179 }
180
181 exit:
182 return error;
183 }
184
ReadStringTlv(const Message & aMessage,uint16_t aOffset,uint8_t aMaxStringLength,char * aValue)185 Error Tlv::ReadStringTlv(const Message &aMessage, uint16_t aOffset, uint8_t aMaxStringLength, char *aValue)
186 {
187 Error error = kErrorNone;
188 ParsedInfo info;
189 uint16_t length;
190
191 SuccessOrExit(error = info.ParseFrom(aMessage, aOffset));
192
193 length = Min(info.mLength, static_cast<uint16_t>(aMaxStringLength));
194
195 aMessage.ReadBytes(info.mValueOffset, aValue, length);
196 aValue[length] = '\0';
197
198 exit:
199 return error;
200 }
201
ReadUintTlv(const Message & aMessage,uint16_t aOffset,UintType & aValue)202 template <typename UintType> Error Tlv::ReadUintTlv(const Message &aMessage, uint16_t aOffset, UintType &aValue)
203 {
204 Error error;
205
206 SuccessOrExit(error = ReadTlvValue(aMessage, aOffset, &aValue, sizeof(aValue)));
207 aValue = Encoding::BigEndian::HostSwap<UintType>(aValue);
208
209 exit:
210 return error;
211 }
212
213 // Explicit instantiations of `ReadUintTlv<>()`
214 template Error Tlv::ReadUintTlv<uint8_t>(const Message &aMessage, uint16_t aOffset, uint8_t &aValue);
215 template Error Tlv::ReadUintTlv<uint16_t>(const Message &aMessage, uint16_t aOffset, uint16_t &aValue);
216 template Error Tlv::ReadUintTlv<uint32_t>(const Message &aMessage, uint16_t aOffset, uint32_t &aValue);
217
ReadTlvValue(const Message & aMessage,uint16_t aOffset,void * aValue,uint8_t aMinLength)218 Error Tlv::ReadTlvValue(const Message &aMessage, uint16_t aOffset, void *aValue, uint8_t aMinLength)
219 {
220 Error error;
221 ParsedInfo info;
222
223 SuccessOrExit(error = info.ParseFrom(aMessage, aOffset));
224
225 VerifyOrExit(info.mLength >= aMinLength, error = kErrorParse);
226
227 aMessage.ReadBytes(info.mValueOffset, aValue, aMinLength);
228
229 exit:
230 return error;
231 }
232
FindStringTlv(const Message & aMessage,uint8_t aType,uint8_t aMaxStringLength,char * aValue)233 Error Tlv::FindStringTlv(const Message &aMessage, uint8_t aType, uint8_t aMaxStringLength, char *aValue)
234 {
235 Error error = kErrorNone;
236 uint16_t offset;
237
238 SuccessOrExit(error = FindTlvOffset(aMessage, aType, offset));
239 error = ReadStringTlv(aMessage, offset, aMaxStringLength, aValue);
240
241 exit:
242 return error;
243 }
244
FindUintTlv(const Message & aMessage,uint8_t aType,UintType & aValue)245 template <typename UintType> Error Tlv::FindUintTlv(const Message &aMessage, uint8_t aType, UintType &aValue)
246 {
247 Error error = kErrorNone;
248 uint16_t offset;
249
250 SuccessOrExit(error = FindTlvOffset(aMessage, aType, offset));
251 error = ReadUintTlv<UintType>(aMessage, offset, aValue);
252
253 exit:
254 return error;
255 }
256
257 // Explicit instantiations of `FindUintTlv<>()`
258 template Error Tlv::FindUintTlv<uint8_t>(const Message &aMessage, uint8_t aType, uint8_t &aValue);
259 template Error Tlv::FindUintTlv<uint16_t>(const Message &aMessage, uint8_t aType, uint16_t &aValue);
260 template Error Tlv::FindUintTlv<uint32_t>(const Message &aMessage, uint8_t aType, uint32_t &aValue);
261
FindTlv(const Message & aMessage,uint8_t aType,void * aValue,uint8_t aLength)262 Error Tlv::FindTlv(const Message &aMessage, uint8_t aType, void *aValue, uint8_t aLength)
263 {
264 Error error;
265 uint16_t offset;
266 uint16_t length;
267
268 SuccessOrExit(error = FindTlvValueOffset(aMessage, aType, offset, length));
269 VerifyOrExit(length >= aLength, error = kErrorParse);
270 aMessage.ReadBytes(offset, aValue, aLength);
271
272 exit:
273 return error;
274 }
275
AppendStringTlv(Message & aMessage,uint8_t aType,uint8_t aMaxStringLength,const char * aValue)276 Error Tlv::AppendStringTlv(Message &aMessage, uint8_t aType, uint8_t aMaxStringLength, const char *aValue)
277 {
278 uint16_t length = (aValue == nullptr) ? 0 : StringLength(aValue, aMaxStringLength);
279
280 return AppendTlv(aMessage, aType, aValue, static_cast<uint8_t>(length));
281 }
282
AppendUintTlv(Message & aMessage,uint8_t aType,UintType aValue)283 template <typename UintType> Error Tlv::AppendUintTlv(Message &aMessage, uint8_t aType, UintType aValue)
284 {
285 UintType value = Encoding::BigEndian::HostSwap<UintType>(aValue);
286
287 return AppendTlv(aMessage, aType, &value, sizeof(UintType));
288 }
289
290 // Explicit instantiations of `AppendUintTlv<>()`
291 template Error Tlv::AppendUintTlv<uint8_t>(Message &aMessage, uint8_t aType, uint8_t aValue);
292 template Error Tlv::AppendUintTlv<uint16_t>(Message &aMessage, uint8_t aType, uint16_t aValue);
293 template Error Tlv::AppendUintTlv<uint32_t>(Message &aMessage, uint8_t aType, uint32_t aValue);
294
AppendTlv(Message & aMessage,uint8_t aType,const void * aValue,uint8_t aLength)295 Error Tlv::AppendTlv(Message &aMessage, uint8_t aType, const void *aValue, uint8_t aLength)
296 {
297 Error error = kErrorNone;
298 Tlv tlv;
299
300 OT_ASSERT(aLength <= Tlv::kBaseTlvMaxLength);
301
302 tlv.SetType(aType);
303 tlv.SetLength(aLength);
304 SuccessOrExit(error = aMessage.Append(tlv));
305
306 VerifyOrExit(aLength > 0);
307 error = aMessage.AppendBytes(aValue, aLength);
308
309 exit:
310 return error;
311 }
312
313 } // namespace ot
314