1 /*
2 * Copyright (c) 2020, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "common/encoding.hpp"
30 #include "common/message.hpp"
31 #include "common/numeric_limits.hpp"
32 #include "common/random.hpp"
33 #include "instance/instance.hpp"
34 #include "net/checksum.hpp"
35 #include "net/icmp6.hpp"
36 #include "net/ip4_types.hpp"
37 #include "net/udp6.hpp"
38
39 #include "test_platform.h"
40 #include "test_util.hpp"
41
42 namespace ot {
43
uint16_t CalculateChecksum(const void *aBuffer, uint16_t aLength)
{
    // Reference one's complement sum in the style of RFC 1071.
    //
    // The buffer is read as consecutive big-endian 16-bit words. An odd
    // trailing byte is treated as the high byte of a word whose low byte is
    // zero. Carries out of bit 15 are folded back into the low 16 bits.
    //
    // Returns the folded sum itself, not its complement.

    const uint8_t *cursor      = static_cast<const uint8_t *>(aBuffer);
    const uint8_t *end         = cursor + aLength;
    uint32_t       accumulator = 0;

    for (; end - cursor >= 2; cursor += 2)
    {
        accumulator += (static_cast<uint32_t>(cursor[0]) << 8) | cursor[1];
    }

    if (cursor != end)
    {
        // Odd length: the last byte is the high half of a zero-padded word.
        accumulator += static_cast<uint32_t>(*cursor) << 8;
    }

    // Fold the 32-bit accumulator down to 16 bits (end-around carry).
    while (accumulator > 0xffff)
    {
        accumulator = (accumulator >> 16) + (accumulator & 0xffff);
    }

    return static_cast<uint16_t>(accumulator);
}
73
CalculateChecksum(const Ip6::Address & aSource,const Ip6::Address & aDestination,uint8_t aIpProto,const Message & aMessage)74 uint16_t CalculateChecksum(const Ip6::Address &aSource,
75 const Ip6::Address &aDestination,
76 uint8_t aIpProto,
77 const Message &aMessage)
78 {
79 // This method calculates the checksum over an IPv6 message.
80 constexpr uint16_t kMaxPayload = 1024;
81
82 OT_TOOL_PACKED_BEGIN
83 struct PseudoHeader
84 {
85 Ip6::Address mSource;
86 Ip6::Address mDestination;
87 uint32_t mPayloadLength;
88 uint32_t mProtocol;
89 } OT_TOOL_PACKED_END;
90
91 OT_TOOL_PACKED_BEGIN
92 struct ChecksumData
93 {
94 PseudoHeader mPseudoHeader;
95 uint8_t mPayload[kMaxPayload];
96 } OT_TOOL_PACKED_END;
97
98 ChecksumData data;
99 uint16_t payloadLength;
100
101 payloadLength = aMessage.GetLength() - aMessage.GetOffset();
102
103 data.mPseudoHeader.mSource = aSource;
104 data.mPseudoHeader.mDestination = aDestination;
105 data.mPseudoHeader.mProtocol = BigEndian::HostSwap32(aIpProto);
106 data.mPseudoHeader.mPayloadLength = BigEndian::HostSwap32(payloadLength);
107
108 SuccessOrQuit(aMessage.Read(aMessage.GetOffset(), data.mPayload, payloadLength));
109
110 return CalculateChecksum(&data, sizeof(PseudoHeader) + payloadLength);
111 }
112
CalculateChecksum(const Ip4::Address & aSource,const Ip4::Address & aDestination,uint8_t aIpProto,const Message & aMessage)113 uint16_t CalculateChecksum(const Ip4::Address &aSource,
114 const Ip4::Address &aDestination,
115 uint8_t aIpProto,
116 const Message &aMessage)
117 {
118 // This method calculates the checksum over an IPv4 message.
119 constexpr uint16_t kMaxPayload = 1024;
120
121 OT_TOOL_PACKED_BEGIN
122 struct PseudoHeader
123 {
124 Ip4::Address mSource;
125 Ip4::Address mDestination;
126 uint16_t mPayloadLength;
127 uint16_t mProtocol;
128 } OT_TOOL_PACKED_END;
129
130 OT_TOOL_PACKED_BEGIN
131 struct ChecksumData
132 {
133 PseudoHeader mPseudoHeader;
134 uint8_t mPayload[kMaxPayload];
135 } OT_TOOL_PACKED_END;
136
137 ChecksumData data;
138 uint16_t payloadLength;
139
140 payloadLength = aMessage.GetLength() - aMessage.GetOffset();
141
142 data.mPseudoHeader.mSource = aSource;
143 data.mPseudoHeader.mDestination = aDestination;
144 data.mPseudoHeader.mProtocol = BigEndian::HostSwap16(aIpProto);
145 data.mPseudoHeader.mPayloadLength = BigEndian::HostSwap16(payloadLength);
146
147 SuccessOrQuit(aMessage.Read(aMessage.GetOffset(), data.mPayload, payloadLength));
148
149 return CalculateChecksum(&data, sizeof(PseudoHeader) + payloadLength);
150 }
151
CorruptMessage(Message & aMessage)152 void CorruptMessage(Message &aMessage)
153 {
154 // Change a random bit in the message.
155
156 uint16_t byteOffset;
157 uint8_t bitOffset;
158 uint8_t byte;
159
160 byteOffset = Random::NonCrypto::GetUint16InRange(0, aMessage.GetLength());
161
162 SuccessOrQuit(aMessage.Read(byteOffset, byte));
163
164 bitOffset = Random::NonCrypto::GetUint8InRange(0, kBitsPerByte);
165
166 byte ^= (1 << bitOffset);
167
168 aMessage.Write(byteOffset, byte);
169 }
170
TestUdpMessageChecksum(void)171 void TestUdpMessageChecksum(void)
172 {
173 constexpr uint16_t kMinSize = sizeof(Ip6::Udp::Header);
174 constexpr uint16_t kMaxSize = kBufferSize * 3 + 24;
175
176 const char *kSourceAddress = "fd00:1122:3344:5566:7788:99aa:bbcc:ddee";
177 const char *kDestAddress = "fd01:2345:6789:abcd:ef01:2345:6789:abcd";
178
179 Instance *instance = static_cast<Instance *>(testInitInstance());
180
181 VerifyOrQuit(instance != nullptr);
182
183 for (uint16_t size = kMinSize; size <= kMaxSize; size++)
184 {
185 Message *message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip6::Udp::Header));
186 Ip6::Udp::Header udpHeader;
187 Ip6::MessageInfo messageInfo;
188
189 VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
190 SuccessOrQuit(message->SetLength(size));
191
192 // Write UDP header with a random payload.
193
194 Random::NonCrypto::Fill(udpHeader);
195 udpHeader.SetChecksum(0);
196 message->Write(0, udpHeader);
197
198 if (size > sizeof(udpHeader))
199 {
200 uint8_t buffer[kMaxSize];
201 uint16_t payloadSize = size - sizeof(udpHeader);
202
203 Random::NonCrypto::FillBuffer(buffer, payloadSize);
204 message->WriteBytes(sizeof(udpHeader), &buffer[0], payloadSize);
205 }
206
207 SuccessOrQuit(messageInfo.GetSockAddr().FromString(kSourceAddress));
208 SuccessOrQuit(messageInfo.GetPeerAddr().FromString(kDestAddress));
209
210 // Verify that the `Checksum::UpdateMessageChecksum` correctly
211 // updates the checksum field in the UDP header on the message.
212
213 Checksum::UpdateMessageChecksum(*message, messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoUdp);
214
215 SuccessOrQuit(message->Read(message->GetOffset(), udpHeader));
216 VerifyOrQuit(udpHeader.GetChecksum() != 0);
217
218 // Verify that the calculated UDP checksum is valid.
219
220 VerifyOrQuit(CalculateChecksum(messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoUdp,
221 *message) == 0xffff);
222
223 // Verify that `Checksum::VerifyMessageChecksum()` accepts the
224 // message and its calculated checksum.
225
226 SuccessOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoUdp));
227
228 // Corrupt the message and verify that checksum is no longer accepted.
229
230 CorruptMessage(*message);
231
232 VerifyOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoUdp) != kErrorNone,
233 "Checksum passed on corrupted message");
234
235 message->Free();
236 }
237 }
238
TestIcmp6MessageChecksum(void)239 void TestIcmp6MessageChecksum(void)
240 {
241 constexpr uint16_t kMinSize = sizeof(Ip6::Icmp::Header);
242 constexpr uint16_t kMaxSize = kBufferSize * 3 + 24;
243
244 const char *kSourceAddress = "fd00:feef:dccd:baab:9889:7667:5444:3223";
245 const char *kDestAddress = "fd01:abab:beef:cafe:1234:5678:9abc:0";
246
247 Instance *instance = static_cast<Instance *>(testInitInstance());
248
249 VerifyOrQuit(instance != nullptr, "Null OpenThread instance\n");
250
251 for (uint16_t size = kMinSize; size <= kMaxSize; size++)
252 {
253 Message *message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip6::Icmp::Header));
254 Ip6::Icmp::Header icmp6Header;
255 Ip6::MessageInfo messageInfo;
256
257 VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
258 SuccessOrQuit(message->SetLength(size));
259
260 // Write ICMP6 header with a random payload.
261
262 Random::NonCrypto::Fill(icmp6Header);
263 icmp6Header.SetChecksum(0);
264 message->Write(0, icmp6Header);
265
266 if (size > sizeof(icmp6Header))
267 {
268 uint8_t buffer[kMaxSize];
269 uint16_t payloadSize = size - sizeof(icmp6Header);
270
271 Random::NonCrypto::FillBuffer(buffer, payloadSize);
272 message->WriteBytes(sizeof(icmp6Header), &buffer[0], payloadSize);
273 }
274
275 SuccessOrQuit(messageInfo.GetSockAddr().FromString(kSourceAddress));
276 SuccessOrQuit(messageInfo.GetPeerAddr().FromString(kDestAddress));
277
278 // Verify that the `Checksum::UpdateMessageChecksum` correctly
279 // updates the checksum field in the ICMP6 header on the message.
280
281 Checksum::UpdateMessageChecksum(*message, messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(),
282 Ip6::kProtoIcmp6);
283
284 SuccessOrQuit(message->Read(message->GetOffset(), icmp6Header));
285 VerifyOrQuit(icmp6Header.GetChecksum() != 0, "Failed to update checksum");
286
287 // Verify that the calculated ICMP6 checksum is valid.
288
289 VerifyOrQuit(CalculateChecksum(messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoIcmp6,
290 *message) == 0xffff);
291
292 // Verify that `Checksum::VerifyMessageChecksum()` accepts the
293 // message and its calculated checksum.
294
295 SuccessOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoIcmp6));
296
297 // Corrupt the message and verify that checksum is no longer accepted.
298
299 CorruptMessage(*message);
300
301 VerifyOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoIcmp6) != kErrorNone,
302 "Checksum passed on corrupted message");
303
304 message->Free();
305 }
306 }
307
TestTcp4MessageChecksum(void)308 void TestTcp4MessageChecksum(void)
309 {
310 constexpr size_t kMinSize = sizeof(Ip4::Tcp::Header);
311 constexpr size_t kMaxSize = kBufferSize * 3 + 24;
312
313 const char *kSourceAddress = "12.34.56.78";
314 const char *kDestAddress = "87.65.43.21";
315
316 Ip4::Address sourceAddress;
317 Ip4::Address destAddress;
318
319 Instance *instance = static_cast<Instance *>(testInitInstance());
320
321 VerifyOrQuit(instance != nullptr);
322
323 SuccessOrQuit(sourceAddress.FromString(kSourceAddress));
324 SuccessOrQuit(destAddress.FromString(kDestAddress));
325
326 for (uint16_t size = kMinSize; size <= kMaxSize; size++)
327 {
328 Message *message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip4::Tcp::Header));
329 Ip4::Tcp::Header tcpHeader;
330
331 VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
332 SuccessOrQuit(message->SetLength(size));
333
334 // Write TCP header with a random payload.
335
336 Random::NonCrypto::Fill(tcpHeader);
337 message->Write(0, tcpHeader);
338
339 if (size > sizeof(tcpHeader))
340 {
341 uint8_t buffer[kMaxSize];
342 uint16_t payloadSize = size - sizeof(tcpHeader);
343
344 Random::NonCrypto::FillBuffer(buffer, payloadSize);
345 message->WriteBytes(sizeof(tcpHeader), &buffer[0], payloadSize);
346 }
347
348 // Verify that the `Checksum::UpdateMessageChecksum` correctly
349 // updates the checksum field in the UDP header on the message.
350
351 Checksum::UpdateMessageChecksum(*message, sourceAddress, destAddress, Ip4::kProtoTcp);
352
353 SuccessOrQuit(message->Read(message->GetOffset(), tcpHeader));
354 VerifyOrQuit(tcpHeader.GetChecksum() != 0);
355
356 // Verify that the calculated UDP checksum is valid.
357
358 VerifyOrQuit(CalculateChecksum(sourceAddress, destAddress, Ip4::kProtoTcp, *message) == 0xffff);
359 message->Free();
360 }
361 }
362
TestUdp4MessageChecksum(void)363 void TestUdp4MessageChecksum(void)
364 {
365 constexpr uint16_t kMinSize = sizeof(Ip4::Udp::Header);
366 constexpr uint16_t kMaxSize = kBufferSize * 3 + 24;
367
368 const char *kSourceAddress = "12.34.56.78";
369 const char *kDestAddress = "87.65.43.21";
370
371 Ip4::Address sourceAddress;
372 Ip4::Address destAddress;
373
374 Instance *instance = static_cast<Instance *>(testInitInstance());
375
376 SuccessOrQuit(sourceAddress.FromString(kSourceAddress));
377 SuccessOrQuit(destAddress.FromString(kDestAddress));
378
379 VerifyOrQuit(instance != nullptr);
380
381 for (uint16_t size = kMinSize; size <= kMaxSize; size++)
382 {
383 Message *message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip4::Udp::Header));
384 Ip4::Udp::Header udpHeader;
385
386 VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
387 SuccessOrQuit(message->SetLength(size));
388
389 // Write UDP header with a random payload.
390
391 Random::NonCrypto::Fill(udpHeader);
392 udpHeader.SetChecksum(0);
393 message->Write(0, udpHeader);
394
395 if (size > sizeof(udpHeader))
396 {
397 uint8_t buffer[kMaxSize];
398 uint16_t payloadSize = size - sizeof(udpHeader);
399
400 Random::NonCrypto::FillBuffer(buffer, payloadSize);
401 message->WriteBytes(sizeof(udpHeader), &buffer[0], payloadSize);
402 }
403
404 // Verify that the `Checksum::UpdateMessageChecksum` correctly
405 // updates the checksum field in the UDP header on the message.
406
407 Checksum::UpdateMessageChecksum(*message, sourceAddress, destAddress, Ip4::kProtoUdp);
408
409 SuccessOrQuit(message->Read(message->GetOffset(), udpHeader));
410 VerifyOrQuit(udpHeader.GetChecksum() != 0);
411
412 // Verify that the calculated UDP checksum is valid.
413
414 VerifyOrQuit(CalculateChecksum(sourceAddress, destAddress, Ip4::kProtoUdp, *message) == 0xffff);
415 message->Free();
416 }
417 }
418
TestIcmp4MessageChecksum(void)419 void TestIcmp4MessageChecksum(void)
420 {
421 // A captured ICMP echo request (ping) message. Checksum field is set to zero.
422 const uint8_t kExampleIcmpMessage[] = "\x08\x00\x00\x00\x67\x2e\x00\x00\x62\xaf\xf1\x61\x00\x04\xfc\x24"
423 "\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17"
424 "\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20\x21\x22\x23\x24\x25\x26\x27"
425 "\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30\x31\x32\x33\x34\x35\x36\x37";
426 uint16_t kChecksumForExampleMessage = 0x5594;
427 Instance *instance = static_cast<Instance *>(testInitInstance());
428 Message *message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(kExampleIcmpMessage));
429
430 Ip4::Address source;
431 Ip4::Address dest;
432
433 uint8_t mPayload[sizeof(kExampleIcmpMessage)];
434 Ip4::Icmp::Header icmpHeader;
435
436 SuccessOrQuit(message->AppendBytes(kExampleIcmpMessage, sizeof(kExampleIcmpMessage)));
437
438 // Random IPv4 address, ICMP message checksum does not include a presudo header like TCP and UDP.
439 source.mFields.m32 = 0x12345678;
440 dest.mFields.m32 = 0x87654321;
441
442 Checksum::UpdateMessageChecksum(*message, source, dest, Ip4::kProtoIcmp);
443
444 SuccessOrQuit(message->Read(0, icmpHeader));
445 VerifyOrQuit(icmpHeader.GetChecksum() == kChecksumForExampleMessage);
446
447 SuccessOrQuit(message->Read(message->GetOffset(), mPayload, sizeof(mPayload)));
448 VerifyOrQuit(CalculateChecksum(mPayload, sizeof(mPayload)) == 0xffff);
449 }
450
451 class ChecksumTester
452 {
453 public:
TestExampleVector(void)454 static void TestExampleVector(void)
455 {
456 // Example from RFC 1071
457 const uint8_t kTestVector[] = {0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7};
458 const uint16_t kTestVectorChecksum = 0xddf2;
459
460 Checksum checksum;
461
462 VerifyOrQuit(checksum.GetValue() == 0, "Incorrect initial checksum value");
463
464 checksum.AddData(kTestVector, sizeof(kTestVector));
465 VerifyOrQuit(checksum.GetValue() == kTestVectorChecksum);
466 VerifyOrQuit(checksum.GetValue() == CalculateChecksum(kTestVector, sizeof(kTestVector)), );
467 }
468 };
469
470 } // namespace ot
471
main(void)472 int main(void)
473 {
474 ot::ChecksumTester::TestExampleVector();
475 ot::TestUdpMessageChecksum();
476 ot::TestIcmp6MessageChecksum();
477 ot::TestTcp4MessageChecksum();
478 ot::TestUdp4MessageChecksum();
479 ot::TestIcmp4MessageChecksum();
480 printf("All tests passed\n");
481 return 0;
482 }
483