1 /*
2 * Copyright (c) 2020, The OpenThread Authors.
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * 1. Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * 2. Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * 3. Neither the name of the copyright holder nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26 * POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #include "common/encoding.hpp"
30 #include "common/instance.hpp"
31 #include "common/message.hpp"
32 #include "common/random.hpp"
33 #include "net/checksum.hpp"
34 #include "net/icmp6.hpp"
35 #include "net/udp6.hpp"
36
37 #include "test_platform.h"
38 #include "test_util.hpp"
39
40 namespace ot {
41
CalculateChecksum(const void * aBuffer,uint16_t aLength)42 uint16_t CalculateChecksum(const void *aBuffer, uint16_t aLength)
43 {
44 // Calculates checksum over a given buffer data. This implementation
45 // is inspired by the algorithm from RFC-1071.
46
47 uint32_t sum = 0;
48 const uint8_t *bytes = reinterpret_cast<const uint8_t *>(aBuffer);
49
50 while (aLength >= sizeof(uint16_t))
51 {
52 sum += Encoding::BigEndian::ReadUint16(bytes);
53 bytes += sizeof(uint16_t);
54 aLength -= sizeof(uint16_t);
55 }
56
57 if (aLength > 0)
58 {
59 sum += (static_cast<uint32_t>(bytes[0]) << 8);
60 }
61
62 // Fold 32-bit sum to 16 bits.
63
64 while (sum >> 16)
65 {
66 sum = (sum & 0xffff) + (sum >> 16);
67 }
68
69 return static_cast<uint16_t>(sum & 0xffff);
70 }
71
CalculateChecksum(const Ip6::Address & aSource,const Ip6::Address & aDestination,uint8_t aIpProto,const Message & aMessage)72 uint16_t CalculateChecksum(const Ip6::Address &aSource,
73 const Ip6::Address &aDestination,
74 uint8_t aIpProto,
75 const Message & aMessage)
76 {
77 // This method calculates the checksum over an IPv6 message.
78
79 enum : uint16_t
80 {
81 kMaxPayload = 1024,
82 };
83
84 OT_TOOL_PACKED_BEGIN
85 struct PseudoHeader
86 {
87 Ip6::Address mSource;
88 Ip6::Address mDestination;
89 uint32_t mPayloadLength;
90 uint32_t mProtocol;
91 } OT_TOOL_PACKED_END;
92
93 OT_TOOL_PACKED_BEGIN
94 struct ChecksumData
95 {
96 PseudoHeader mPseudoHeader;
97 uint8_t mPayload[kMaxPayload];
98 } OT_TOOL_PACKED_END;
99
100 ChecksumData data;
101 uint16_t payloadLength;
102
103 payloadLength = aMessage.GetLength() - aMessage.GetOffset();
104
105 data.mPseudoHeader.mSource = aSource;
106 data.mPseudoHeader.mDestination = aDestination;
107 data.mPseudoHeader.mProtocol = Encoding::BigEndian::HostSwap32(aIpProto);
108 data.mPseudoHeader.mPayloadLength = Encoding::BigEndian::HostSwap32(payloadLength);
109
110 SuccessOrQuit(aMessage.Read(aMessage.GetOffset(), data.mPayload, payloadLength));
111
112 return CalculateChecksum(&data, sizeof(PseudoHeader) + payloadLength);
113 }
114
CorruptMessage(Message & aMessage)115 void CorruptMessage(Message &aMessage)
116 {
117 // Change a random bit in the message.
118
119 uint16_t byteOffset;
120 uint8_t bitOffset;
121 uint8_t byte;
122
123 byteOffset = Random::NonCrypto::GetUint16InRange(0, aMessage.GetLength());
124
125 SuccessOrQuit(aMessage.Read(byteOffset, byte));
126
127 bitOffset = Random::NonCrypto::GetUint8InRange(0, CHAR_BIT);
128
129 byte ^= (1 << bitOffset);
130
131 aMessage.Write(byteOffset, byte);
132 }
133
TestUdpMessageChecksum(void)134 void TestUdpMessageChecksum(void)
135 {
136 enum : uint16_t
137 {
138 kMinSize = sizeof(Ip6::Udp::Header),
139 kMaxSize = kBufferSize * 3 + 24,
140 };
141
142 const char *kSourceAddress = "fd00:1122:3344:5566:7788:99aa:bbcc:ddee";
143 const char *kDestAddress = "fd01:2345:6789:abcd:ef01:2345:6789:abcd";
144
145 Instance *instance = static_cast<Instance *>(testInitInstance());
146
147 VerifyOrQuit(instance != nullptr);
148
149 for (uint16_t size = kMinSize; size <= kMaxSize; size++)
150 {
151 Message * message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip6::Udp::Header));
152 Ip6::Udp::Header udpHeader;
153 Ip6::MessageInfo messageInfo;
154
155 VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
156 SuccessOrQuit(message->SetLength(size));
157
158 // Write UDP header with a random payload.
159
160 Random::NonCrypto::FillBuffer(reinterpret_cast<uint8_t *>(&udpHeader), sizeof(udpHeader));
161 udpHeader.SetChecksum(0);
162 message->Write(0, udpHeader);
163
164 if (size > sizeof(udpHeader))
165 {
166 uint8_t buffer[kMaxSize];
167 uint16_t payloadSize = size - sizeof(udpHeader);
168
169 Random::NonCrypto::FillBuffer(buffer, payloadSize);
170 message->WriteBytes(sizeof(udpHeader), &buffer[0], payloadSize);
171 }
172
173 SuccessOrQuit(messageInfo.GetSockAddr().FromString(kSourceAddress));
174 SuccessOrQuit(messageInfo.GetPeerAddr().FromString(kDestAddress));
175
176 // Verify that the `Checksum::UpdateMessageChecksum` correctly
177 // updates the checksum field in the UDP header on the message.
178
179 Checksum::UpdateMessageChecksum(*message, messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoUdp);
180
181 SuccessOrQuit(message->Read(message->GetOffset(), udpHeader));
182 VerifyOrQuit(udpHeader.GetChecksum() != 0);
183
184 // Verify that the calculated UDP checksum is valid.
185
186 VerifyOrQuit(CalculateChecksum(messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoUdp,
187 *message) == 0xffff);
188
189 // Verify that `Checksum::VerifyMessageChecksum()` accepts the
190 // message and its calculated checksum.
191
192 SuccessOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoUdp));
193
194 // Corrupt the message and verify that checksum is no longer accepted.
195
196 CorruptMessage(*message);
197
198 VerifyOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoUdp) != kErrorNone,
199 "Checksum passed on corrupted message");
200
201 message->Free();
202 }
203 }
204
TestIcmp6MessageChecksum(void)205 void TestIcmp6MessageChecksum(void)
206 {
207 enum : uint16_t
208 {
209 kMinSize = sizeof(Ip6::Icmp::Header),
210 kMaxSize = kBufferSize * 3 + 24,
211 };
212
213 const char *kSourceAddress = "fd00:feef:dccd:baab:9889:7667:5444:3223";
214 const char *kDestAddress = "fd01:abab:beef:cafe:1234:5678:9abc:0";
215
216 Instance *instance = static_cast<Instance *>(testInitInstance());
217
218 VerifyOrQuit(instance != nullptr, "Null OpenThread instance\n");
219
220 for (uint16_t size = kMinSize; size <= kMaxSize; size++)
221 {
222 Message * message = instance->Get<Ip6::Ip6>().NewMessage(sizeof(Ip6::Icmp::Header));
223 Ip6::Icmp::Header icmp6Header;
224 Ip6::MessageInfo messageInfo;
225
226 VerifyOrQuit(message != nullptr, "Ip6::NewMesssage() failed");
227 SuccessOrQuit(message->SetLength(size));
228
229 // Write ICMP6 header with a random payload.
230
231 Random::NonCrypto::FillBuffer(reinterpret_cast<uint8_t *>(&icmp6Header), sizeof(icmp6Header));
232 icmp6Header.SetChecksum(0);
233 message->Write(0, icmp6Header);
234
235 if (size > sizeof(icmp6Header))
236 {
237 uint8_t buffer[kMaxSize];
238 uint16_t payloadSize = size - sizeof(icmp6Header);
239
240 Random::NonCrypto::FillBuffer(buffer, payloadSize);
241 message->WriteBytes(sizeof(icmp6Header), &buffer[0], payloadSize);
242 }
243
244 SuccessOrQuit(messageInfo.GetSockAddr().FromString(kSourceAddress));
245 SuccessOrQuit(messageInfo.GetPeerAddr().FromString(kDestAddress));
246
247 // Verify that the `Checksum::UpdateMessageChecksum` correctly
248 // updates the checksum field in the ICMP6 header on the message.
249
250 Checksum::UpdateMessageChecksum(*message, messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(),
251 Ip6::kProtoIcmp6);
252
253 SuccessOrQuit(message->Read(message->GetOffset(), icmp6Header));
254 VerifyOrQuit(icmp6Header.GetChecksum() != 0, "Failed to update checksum");
255
256 // Verify that the calculated ICMP6 checksum is valid.
257
258 VerifyOrQuit(CalculateChecksum(messageInfo.GetSockAddr(), messageInfo.GetPeerAddr(), Ip6::kProtoIcmp6,
259 *message) == 0xffff);
260
261 // Verify that `Checksum::VerifyMessageChecksum()` accepts the
262 // message and its calculated checksum.
263
264 SuccessOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoIcmp6));
265
266 // Corrupt the message and verify that checksum is no longer accepted.
267
268 CorruptMessage(*message);
269
270 VerifyOrQuit(Checksum::VerifyMessageChecksum(*message, messageInfo, Ip6::kProtoIcmp6) != kErrorNone,
271 "Checksum passed on corrupted message");
272
273 message->Free();
274 }
275 }
276
277 class ChecksumTester
278 {
279 public:
TestExampleVector(void)280 static void TestExampleVector(void)
281 {
282 // Example from RFC 1071
283 const uint8_t kTestVector[] = {0x00, 0x01, 0xf2, 0x03, 0xf4, 0xf5, 0xf6, 0xf7};
284 const uint16_t kTestVectorChecksum = 0xddf2;
285
286 Checksum checksum;
287
288 VerifyOrQuit(checksum.GetValue() == 0, "Incorrect initial checksum value");
289
290 checksum.AddData(kTestVector, sizeof(kTestVector));
291 VerifyOrQuit(checksum.GetValue() == kTestVectorChecksum);
292 VerifyOrQuit(checksum.GetValue() == CalculateChecksum(kTestVector, sizeof(kTestVector)), );
293 }
294 };
295
296 } // namespace ot
297
main(void)298 int main(void)
299 {
300 ot::ChecksumTester::TestExampleVector();
301 ot::TestUdpMessageChecksum();
302 ot::TestIcmp6MessageChecksum();
303 printf("All tests passed\n");
304 return 0;
305 }
306