1 
2 /*
3  *  Copyright (c) 2018, The OpenThread Authors.
4  *  All rights reserved.
5  *
6  *  Redistribution and use in source and binary forms, with or without
7  *  modification, are permitted provided that the following conditions are met:
8  *  1. Redistributions of source code must retain the above copyright
9  *     notice, this list of conditions and the following disclaimer.
10  *  2. Redistributions in binary form must reproduce the above copyright
11  *     notice, this list of conditions and the following disclaimer in the
12  *     documentation and/or other materials provided with the distribution.
13  *  3. Neither the name of the copyright holder nor the
14  *     names of its contributors may be used to endorse or promote products
15  *     derived from this software without specific prior written permission.
16  *
17  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
21  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27  *  POSSIBILITY OF SUCH DAMAGE.
28  */
29 
30 #include "sntp_client.hpp"
31 
32 #if OPENTHREAD_CONFIG_SNTP_CLIENT_ENABLE
33 
34 #include "common/as_core_type.hpp"
35 #include "common/code_utils.hpp"
36 #include "common/debug.hpp"
37 #include "common/locator_getters.hpp"
38 #include "common/log.hpp"
39 #include "instance/instance.hpp"
40 #include "net/udp6.hpp"
41 #include "thread/thread_netif.hpp"
42 
43 /**
44  * @file
45  *   This file implements the SNTP client.
46  */
47 
48 namespace ot {
49 namespace Sntp {
50 
// Log module name for this file; presumably used to tag the `LogInfo()`/`LogWarn()` output
// emitted below (see `common/log.hpp` to confirm).
RegisterLogModule("SntpClnt");
52 
// Constructs the SNTP client.
//
// The UDP socket and the retransmission timer are both constructed with `aInstance`. The unix era
// starts at 0; it looks like this is the value later read through `GetUnixEra()` when a response
// is converted to unix time (confirm against the class declaration).
Client::Client(Instance &aInstance)
    : mSocket(aInstance)
    , mRetransmissionTimer(aInstance)
    , mUnixEra(0)
{
}
59 
Start(void)60 Error Client::Start(void)
61 {
62     Error error;
63 
64     SuccessOrExit(error = mSocket.Open(&Client::HandleUdpReceive, this));
65     SuccessOrExit(error = mSocket.Bind(0, Ip6::kNetifUnspecified));
66 
67 exit:
68     return error;
69 }
70 
Stop(void)71 Error Client::Stop(void)
72 {
73     for (Message &message : mPendingQueries)
74     {
75         QueryMetadata queryMetadata;
76 
77         queryMetadata.ReadFrom(message);
78         FinalizeSntpTransaction(message, queryMetadata, 0, kErrorAbort);
79     }
80 
81     return mSocket.Close();
82 }
83 
Query(const otSntpQuery * aQuery,otSntpResponseHandler aHandler,void * aContext)84 Error Client::Query(const otSntpQuery *aQuery, otSntpResponseHandler aHandler, void *aContext)
85 {
86     Error                   error;
87     QueryMetadata           queryMetadata;
88     Message                *message     = nullptr;
89     Message                *messageCopy = nullptr;
90     Header                  header;
91     const Ip6::MessageInfo *messageInfo;
92 
93     VerifyOrExit(aQuery->mMessageInfo != nullptr, error = kErrorInvalidArgs);
94 
95     header.Init();
96 
97     // Originate timestamp is used only as a unique token.
98     header.SetTransmitTimestampSeconds(TimerMilli::GetNow().GetValue() / 1000 + kTimeAt1970);
99 
100     VerifyOrExit((message = NewMessage(header)) != nullptr, error = kErrorNoBufs);
101 
102     messageInfo = AsCoreTypePtr(aQuery->mMessageInfo);
103 
104     queryMetadata.mResponseHandler.Set(aHandler, aContext);
105     queryMetadata.mTransmitTimestamp   = header.GetTransmitTimestampSeconds();
106     queryMetadata.mTransmissionTime    = TimerMilli::GetNow() + kResponseTimeout;
107     queryMetadata.mSourceAddress       = messageInfo->GetSockAddr();
108     queryMetadata.mDestinationPort     = messageInfo->GetPeerPort();
109     queryMetadata.mDestinationAddress  = messageInfo->GetPeerAddr();
110     queryMetadata.mRetransmissionCount = 0;
111 
112     VerifyOrExit((messageCopy = CopyAndEnqueueMessage(*message, queryMetadata)) != nullptr, error = kErrorNoBufs);
113     SuccessOrExit(error = SendMessage(*message, *messageInfo));
114 
115 exit:
116 
117     if (error != kErrorNone)
118     {
119         if (message)
120         {
121             message->Free();
122         }
123 
124         if (messageCopy)
125         {
126             DequeueMessage(*messageCopy);
127         }
128     }
129 
130     return error;
131 }
132 
NewMessage(const Header & aHeader)133 Message *Client::NewMessage(const Header &aHeader)
134 {
135     Message *message = nullptr;
136 
137     VerifyOrExit((message = mSocket.NewMessage(sizeof(aHeader))) != nullptr);
138     IgnoreError(message->Prepend(aHeader));
139     message->SetOffset(0);
140 
141 exit:
142     return message;
143 }
144 
// Clones `aMessage`, appends `aQueryMetadata` to the clone and places the clone on the pending
// queue, where it is kept for retransmissions. The retransmission timer is (re)armed so that it
// fires no later than `aQueryMetadata.mTransmissionTime`.
//
// Returns the queued clone, or `nullptr` if the clone could not be created or extended.
Message *Client::CopyAndEnqueueMessage(const Message &aMessage, const QueryMetadata &aQueryMetadata)
{
    Message *queuedCopy = aMessage.Clone();

    VerifyOrExit(queuedCopy != nullptr);

    if (aQueryMetadata.AppendTo(*queuedCopy) != kErrorNone)
    {
        // The clone never made it onto the queue, so it only needs to be freed.
        queuedCopy->Free();
        queuedCopy = nullptr;
        ExitNow();
    }

    mPendingQueries.Enqueue(*queuedCopy);
    mRetransmissionTimer.FireAtIfEarlier(aQueryMetadata.mTransmissionTime);

exit:
    return queuedCopy;
}
163 
DequeueMessage(Message & aMessage)164 void Client::DequeueMessage(Message &aMessage)
165 {
166     if (mRetransmissionTimer.IsRunning() && (mPendingQueries.GetHead() == nullptr))
167     {
168         // No more requests pending, stop the timer.
169         mRetransmissionTimer.Stop();
170     }
171 
172     mPendingQueries.DequeueAndFree(aMessage);
173 }
174 
SendMessage(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)175 Error Client::SendMessage(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
176 {
177     return mSocket.SendTo(aMessage, aMessageInfo);
178 }
179 
SendCopy(const Message & aMessage,const Ip6::MessageInfo & aMessageInfo)180 void Client::SendCopy(const Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
181 {
182     Error    error;
183     Message *messageCopy = nullptr;
184 
185     // Create a message copy for lower layers.
186     VerifyOrExit((messageCopy = aMessage.Clone(aMessage.GetLength() - sizeof(QueryMetadata))) != nullptr,
187                  error = kErrorNoBufs);
188 
189     // Send the copy.
190     SuccessOrExit(error = SendMessage(*messageCopy, aMessageInfo));
191 
192 exit:
193     if (error != kErrorNone)
194     {
195         FreeMessage(messageCopy);
196         LogWarn("Failed to send SNTP request: %s", ErrorToString(error));
197     }
198 }
199 
FindRelatedQuery(const Header & aResponseHeader,QueryMetadata & aQueryMetadata)200 Message *Client::FindRelatedQuery(const Header &aResponseHeader, QueryMetadata &aQueryMetadata)
201 {
202     Message *matchedMessage = nullptr;
203 
204     for (Message &message : mPendingQueries)
205     {
206         // Read originate timestamp.
207         aQueryMetadata.ReadFrom(message);
208 
209         if (aQueryMetadata.mTransmitTimestamp == aResponseHeader.GetOriginateTimestampSeconds())
210         {
211             matchedMessage = &message;
212             break;
213         }
214     }
215 
216     return matchedMessage;
217 }
218 
// Ends a query: takes `aQuery` off the pending queue (which also frees it, see `DequeueMessage()`)
// and then reports `aTime` and `aResult` to the query's response handler, if one was set.
//
// `aQueryMetadata` is the caller's own copy of the metadata read from `aQuery`; the handler is
// invoked through that copy, after the message itself has been released. `aQuery` must not be
// used once this call returns.
void Client::FinalizeSntpTransaction(Message             &aQuery,
                                     const QueryMetadata &aQueryMetadata,
                                     uint64_t             aTime,
                                     Error                aResult)
{
    // Dequeue before invoking the handler, so the finished query is no longer on the pending
    // queue by the time user code runs.
    DequeueMessage(aQuery);
    aQueryMetadata.mResponseHandler.InvokeIfSet(aTime, aResult);
}
227 
HandleRetransmissionTimer(void)228 void Client::HandleRetransmissionTimer(void)
229 {
230     TimeMilli        now      = TimerMilli::GetNow();
231     TimeMilli        nextTime = now.GetDistantFuture();
232     QueryMetadata    queryMetadata;
233     Ip6::MessageInfo messageInfo;
234 
235     for (Message &message : mPendingQueries)
236     {
237         queryMetadata.ReadFrom(message);
238 
239         if (now >= queryMetadata.mTransmissionTime)
240         {
241             if (queryMetadata.mRetransmissionCount >= kMaxRetransmit)
242             {
243                 // No expected response.
244                 FinalizeSntpTransaction(message, queryMetadata, 0, kErrorResponseTimeout);
245                 continue;
246             }
247 
248             // Increment retransmission counter and timer.
249             queryMetadata.mRetransmissionCount++;
250             queryMetadata.mTransmissionTime = now + kResponseTimeout;
251             queryMetadata.UpdateIn(message);
252 
253             // Retransmit
254             messageInfo.SetPeerAddr(queryMetadata.mDestinationAddress);
255             messageInfo.SetPeerPort(queryMetadata.mDestinationPort);
256             messageInfo.SetSockAddr(queryMetadata.mSourceAddress);
257 
258             SendCopy(message, messageInfo);
259         }
260 
261         nextTime = Min(nextTime, queryMetadata.mTransmissionTime);
262     }
263 
264     if (nextTime < now.GetDistantFuture())
265     {
266         mRetransmissionTimer.FireAt(nextTime);
267     }
268 }
269 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo)270 void Client::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
271 {
272     static_cast<Client *>(aContext)->HandleUdpReceive(AsCoreType(aMessage), AsCoreType(aMessageInfo));
273 }
274 
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)275 void Client::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
276 {
277     OT_UNUSED_VARIABLE(aMessageInfo);
278 
279     Error         error = kErrorNone;
280     Header        responseHeader;
281     QueryMetadata queryMetadata;
282     Message      *message  = nullptr;
283     uint64_t      unixTime = 0;
284 
285     SuccessOrExit(aMessage.Read(aMessage.GetOffset(), responseHeader));
286 
287     VerifyOrExit((message = FindRelatedQuery(responseHeader, queryMetadata)) != nullptr);
288 
289     // Check if response came from the server.
290     VerifyOrExit(responseHeader.GetMode() == Header::kModeServer, error = kErrorFailed);
291 
292     // Check the Kiss-o'-death packet.
293     if (!responseHeader.GetStratum())
294     {
295         char kissCode[Header::kKissCodeLength + 1];
296 
297         memcpy(kissCode, responseHeader.GetKissCode(), Header::kKissCodeLength);
298         kissCode[Header::kKissCodeLength] = 0;
299 
300         LogInfo("SNTP response contains the Kiss-o'-death packet with %s code", kissCode);
301         ExitNow(error = kErrorBusy);
302     }
303 
304     // Check if timestamp has been set.
305     VerifyOrExit(responseHeader.GetTransmitTimestampSeconds() != 0 &&
306                      responseHeader.GetTransmitTimestampFraction() != 0,
307                  error = kErrorFailed);
308 
309     // The NTP time starts at 1900 while the unix epoch starts at 1970.
310     // Due to NTP protocol limitation, this module stops working correctly after around year 2106, if
311     // unix era is not updated. This seems to be a reasonable limitation for now. Era number cannot be
312     // obtained using NTP protocol, and client of this module is responsible to set it properly.
313     unixTime = GetUnixEra() * (1ULL << 32);
314 
315     if (responseHeader.GetTransmitTimestampSeconds() > kTimeAt1970)
316     {
317         unixTime += static_cast<uint64_t>(responseHeader.GetTransmitTimestampSeconds()) - kTimeAt1970;
318     }
319     else
320     {
321         unixTime += static_cast<uint64_t>(responseHeader.GetTransmitTimestampSeconds()) + (1ULL << 32) - kTimeAt1970;
322     }
323 
324     // Return the time since 1970.
325     FinalizeSntpTransaction(*message, queryMetadata, unixTime, kErrorNone);
326 
327 exit:
328 
329     if (message != nullptr && error != kErrorNone)
330     {
331         FinalizeSntpTransaction(*message, queryMetadata, 0, error);
332     }
333 }
334 
335 } // namespace Sntp
336 } // namespace ot
337 
338 #endif // OPENTHREAD_CONFIG_SNTP_CLIENT_ENABLE
339