1 /*
2  *  Copyright (c) 2016, The OpenThread Authors.
3  *  All rights reserved.
4  *
5  *  Redistribution and use in source and binary forms, with or without
6  *  modification, are permitted provided that the following conditions are met:
7  *  1. Redistributions of source code must retain the above copyright
8  *     notice, this list of conditions and the following disclaimer.
9  *  2. Redistributions in binary form must reproduce the above copyright
10  *     notice, this list of conditions and the following disclaimer in the
11  *     documentation and/or other materials provided with the distribution.
12  *  3. Neither the name of the copyright holder nor the
13  *     names of its contributors may be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17  *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18  *  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19  *  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20  *  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21  *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22  *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23  *  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24  *  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25  *  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
26  *  POSSIBILITY OF SUCH DAMAGE.
27  */
28 
29 /**
30  * @file
31  *   This file implements the Joiner Router role.
32  */
33 
34 #include "joiner_router.hpp"
35 
36 #if OPENTHREAD_FTD
37 
38 #include <stdio.h>
39 
40 #include "common/code_utils.hpp"
41 #include "common/encoding.hpp"
42 #include "common/instance.hpp"
43 #include "common/locator_getters.hpp"
44 #include "common/logging.hpp"
45 #include "meshcop/meshcop.hpp"
46 #include "meshcop/meshcop_tlvs.hpp"
47 #include "thread/mle.hpp"
48 #include "thread/thread_netif.hpp"
49 #include "thread/uri_paths.hpp"
50 
51 namespace ot {
52 namespace MeshCoP {
53 
JoinerRouter(Instance & aInstance)54 JoinerRouter::JoinerRouter(Instance &aInstance)
55     : InstanceLocator(aInstance)
56     , mSocket(aInstance)
57     , mRelayTransmit(UriPath::kRelayTx, &JoinerRouter::HandleRelayTransmit, this)
58     , mTimer(aInstance, JoinerRouter::HandleTimer)
59     , mJoinerUdpPort(0)
60     , mIsJoinerPortConfigured(false)
61 {
62     Get<Tmf::Agent>().AddResource(mRelayTransmit);
63 }
64 
HandleNotifierEvents(Events aEvents)65 void JoinerRouter::HandleNotifierEvents(Events aEvents)
66 {
67     if (aEvents.Contains(kEventThreadNetdataChanged))
68     {
69         Start();
70     }
71 }
72 
Start(void)73 void JoinerRouter::Start(void)
74 {
75     VerifyOrExit(Get<Mle::MleRouter>().IsFullThreadDevice());
76 
77     if (Get<NetworkData::Leader>().IsJoiningEnabled())
78     {
79         uint16_t port = GetJoinerUdpPort();
80 
81         VerifyOrExit(!mSocket.IsBound());
82 
83         IgnoreError(mSocket.Open(&JoinerRouter::HandleUdpReceive, this));
84         IgnoreError(mSocket.Bind(port));
85         IgnoreError(Get<Ip6::Filter>().AddUnsecurePort(port));
86         otLogInfoMeshCoP("Joiner Router: start");
87     }
88     else
89     {
90         VerifyOrExit(mSocket.IsBound());
91 
92         IgnoreError(Get<Ip6::Filter>().RemoveUnsecurePort(mSocket.GetSockName().mPort));
93 
94         IgnoreError(mSocket.Close());
95     }
96 
97 exit:
98     return;
99 }
100 
GetJoinerUdpPort(void)101 uint16_t JoinerRouter::GetJoinerUdpPort(void)
102 {
103     uint16_t                rval = OPENTHREAD_CONFIG_JOINER_UDP_PORT;
104     const JoinerUdpPortTlv *joinerUdpPort;
105 
106     VerifyOrExit(!mIsJoinerPortConfigured, rval = mJoinerUdpPort);
107 
108     joinerUdpPort = static_cast<const JoinerUdpPortTlv *>(
109         Get<NetworkData::Leader>().GetCommissioningDataSubTlv(Tlv::kJoinerUdpPort));
110     VerifyOrExit(joinerUdpPort != nullptr);
111 
112     rval = joinerUdpPort->GetUdpPort();
113 
114 exit:
115     return rval;
116 }
117 
SetJoinerUdpPort(uint16_t aJoinerUdpPort)118 void JoinerRouter::SetJoinerUdpPort(uint16_t aJoinerUdpPort)
119 {
120     mJoinerUdpPort          = aJoinerUdpPort;
121     mIsJoinerPortConfigured = true;
122     Start();
123 }
124 
HandleUdpReceive(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo)125 void JoinerRouter::HandleUdpReceive(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
126 {
127     static_cast<JoinerRouter *>(aContext)->HandleUdpReceive(*static_cast<Message *>(aMessage),
128                                                             *static_cast<const Ip6::MessageInfo *>(aMessageInfo));
129 }
130 
HandleUdpReceive(Message & aMessage,const Ip6::MessageInfo & aMessageInfo)131 void JoinerRouter::HandleUdpReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
132 {
133     Error            error;
134     Coap::Message *  message = nullptr;
135     Ip6::MessageInfo messageInfo;
136     ExtendedTlv      tlv;
137     uint16_t         borderAgentRloc;
138     uint16_t         offset;
139 
140     otLogInfoMeshCoP("JoinerRouter::HandleUdpReceive");
141 
142     SuccessOrExit(error = GetBorderAgentRloc(Get<ThreadNetif>(), borderAgentRloc));
143 
144     VerifyOrExit((message = NewMeshCoPMessage(Get<Tmf::Agent>())) != nullptr, error = kErrorNoBufs);
145 
146     SuccessOrExit(error = message->InitAsNonConfirmablePost(UriPath::kRelayRx));
147     SuccessOrExit(error = message->SetPayloadMarker());
148 
149     SuccessOrExit(error = Tlv::Append<JoinerUdpPortTlv>(*message, aMessageInfo.GetPeerPort()));
150     SuccessOrExit(error = Tlv::Append<JoinerIidTlv>(*message, aMessageInfo.GetPeerAddr().GetIid()));
151     SuccessOrExit(error = Tlv::Append<JoinerRouterLocatorTlv>(*message, Get<Mle::MleRouter>().GetRloc16()));
152 
153     tlv.SetType(Tlv::kJoinerDtlsEncapsulation);
154     tlv.SetLength(aMessage.GetLength() - aMessage.GetOffset());
155     SuccessOrExit(error = message->Append(tlv));
156     offset = message->GetLength();
157     SuccessOrExit(error = message->SetLength(offset + tlv.GetLength()));
158     aMessage.CopyTo(aMessage.GetOffset(), offset, tlv.GetLength(), *message);
159 
160     messageInfo.SetSockAddr(Get<Mle::MleRouter>().GetMeshLocal16());
161     messageInfo.SetPeerAddr(Get<Mle::MleRouter>().GetMeshLocal16());
162     messageInfo.GetPeerAddr().GetIid().SetLocator(borderAgentRloc);
163     messageInfo.SetPeerPort(Tmf::kUdpPort);
164 
165     SuccessOrExit(error = Get<Tmf::Agent>().SendMessage(*message, messageInfo));
166 
167     otLogInfoMeshCoP("Sent relay rx");
168 
169 exit:
170     FreeMessageOnError(message, error);
171 }
172 
HandleRelayTransmit(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo)173 void JoinerRouter::HandleRelayTransmit(void *aContext, otMessage *aMessage, const otMessageInfo *aMessageInfo)
174 {
175     static_cast<JoinerRouter *>(aContext)->HandleRelayTransmit(*static_cast<Coap::Message *>(aMessage),
176                                                                *static_cast<const Ip6::MessageInfo *>(aMessageInfo));
177 }
178 
HandleRelayTransmit(Coap::Message & aMessage,const Ip6::MessageInfo & aMessageInfo)179 void JoinerRouter::HandleRelayTransmit(Coap::Message &aMessage, const Ip6::MessageInfo &aMessageInfo)
180 {
181     OT_UNUSED_VARIABLE(aMessageInfo);
182 
183     Error                    error;
184     uint16_t                 joinerPort;
185     Ip6::InterfaceIdentifier joinerIid;
186     Kek                      kek;
187     uint16_t                 offset;
188     uint16_t                 length;
189     Message *                message = nullptr;
190     Message::Settings        settings(Message::kNoLinkSecurity, Message::kPriorityNet);
191     Ip6::MessageInfo         messageInfo;
192 
193     VerifyOrExit(aMessage.IsNonConfirmablePostRequest(), error = kErrorDrop);
194 
195     otLogInfoMeshCoP("Received relay transmit");
196 
197     SuccessOrExit(error = Tlv::Find<JoinerUdpPortTlv>(aMessage, joinerPort));
198     SuccessOrExit(error = Tlv::Find<JoinerIidTlv>(aMessage, joinerIid));
199 
200     SuccessOrExit(error = Tlv::FindTlvValueOffset(aMessage, Tlv::kJoinerDtlsEncapsulation, offset, length));
201 
202     VerifyOrExit((message = mSocket.NewMessage(0, settings)) != nullptr, error = kErrorNoBufs);
203 
204     SuccessOrExit(error = message->SetLength(length));
205     aMessage.CopyTo(offset, 0, length, *message);
206 
207     messageInfo.GetPeerAddr().SetToLinkLocalAddress(joinerIid);
208     messageInfo.SetPeerPort(joinerPort);
209 
210     SuccessOrExit(error = mSocket.SendTo(*message, messageInfo));
211 
212     if (Tlv::Find<JoinerRouterKekTlv>(aMessage, kek) == kErrorNone)
213     {
214         otLogInfoMeshCoP("Received kek");
215 
216         DelaySendingJoinerEntrust(messageInfo, kek);
217     }
218 
219 exit:
220     FreeMessageOnError(message, error);
221 }
222 
DelaySendingJoinerEntrust(const Ip6::MessageInfo & aMessageInfo,const Kek & aKek)223 void JoinerRouter::DelaySendingJoinerEntrust(const Ip6::MessageInfo &aMessageInfo, const Kek &aKek)
224 {
225     Error                 error   = kErrorNone;
226     Message *             message = Get<MessagePool>().New(Message::kTypeOther, 0);
227     JoinerEntrustMetadata metadata;
228 
229     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
230 
231     metadata.mMessageInfo = aMessageInfo;
232     metadata.mMessageInfo.SetPeerPort(Tmf::kUdpPort);
233     metadata.mSendTime = TimerMilli::GetNow() + kJoinerEntrustTxDelay;
234     metadata.mKek      = aKek;
235 
236     SuccessOrExit(error = metadata.AppendTo(*message));
237 
238     mDelayedJoinEnts.Enqueue(*message);
239 
240     if (!mTimer.IsRunning())
241     {
242         mTimer.FireAt(metadata.mSendTime);
243     }
244 
245 exit:
246     FreeMessageOnError(message, error);
247     LogError("schedule joiner entrust", error);
248 }
249 
HandleTimer(Timer & aTimer)250 void JoinerRouter::HandleTimer(Timer &aTimer)
251 {
252     aTimer.Get<JoinerRouter>().HandleTimer();
253 }
254 
HandleTimer(void)255 void JoinerRouter::HandleTimer(void)
256 {
257     SendDelayedJoinerEntrust();
258 }
259 
SendDelayedJoinerEntrust(void)260 void JoinerRouter::SendDelayedJoinerEntrust(void)
261 {
262     JoinerEntrustMetadata metadata;
263     Message *             message = mDelayedJoinEnts.GetHead();
264 
265     VerifyOrExit(message != nullptr);
266     VerifyOrExit(!mTimer.IsRunning());
267 
268     metadata.ReadFrom(*message);
269 
270     if (TimerMilli::GetNow() < metadata.mSendTime)
271     {
272         mTimer.FireAt(metadata.mSendTime);
273     }
274     else
275     {
276         mDelayedJoinEnts.DequeueAndFree(*message);
277 
278         Get<KeyManager>().SetKek(metadata.mKek);
279 
280         if (SendJoinerEntrust(metadata.mMessageInfo) != kErrorNone)
281         {
282             mTimer.Start(0);
283         }
284     }
285 
286 exit:
287     return;
288 }
289 
SendJoinerEntrust(const Ip6::MessageInfo & aMessageInfo)290 Error JoinerRouter::SendJoinerEntrust(const Ip6::MessageInfo &aMessageInfo)
291 {
292     Error          error = kErrorNone;
293     Coap::Message *message;
294 
295     message = PrepareJoinerEntrustMessage();
296     VerifyOrExit(message != nullptr, error = kErrorNoBufs);
297 
298     IgnoreError(Get<Tmf::Agent>().AbortTransaction(&JoinerRouter::HandleJoinerEntrustResponse, this));
299 
300     otLogInfoMeshCoP("Sending JOIN_ENT.ntf");
301     SuccessOrExit(error = Get<Tmf::Agent>().SendMessage(*message, aMessageInfo,
302                                                         &JoinerRouter::HandleJoinerEntrustResponse, this));
303 
304     otLogInfoMeshCoP("Sent joiner entrust length = %d", message->GetLength());
305     otLogCertMeshCoP("[THCI] direction=send | type=JOIN_ENT.ntf");
306 
307 exit:
308     FreeMessageOnError(message, error);
309     return error;
310 }
311 
PrepareJoinerEntrustMessage(void)312 Coap::Message *JoinerRouter::PrepareJoinerEntrustMessage(void)
313 {
314     Error          error;
315     Coap::Message *message = nullptr;
316     Dataset        dataset;
317 
318     NetworkNameTlv networkName;
319     const Tlv *    tlv;
320 
321     VerifyOrExit((message = NewMeshCoPMessage(Get<Tmf::Agent>())) != nullptr, error = kErrorNoBufs);
322 
323     message->InitAsConfirmablePost();
324     SuccessOrExit(error = message->AppendUriPathOptions(UriPath::kJoinerEntrust));
325     SuccessOrExit(error = message->SetPayloadMarker());
326     message->SetSubType(Message::kSubTypeJoinerEntrust);
327 
328     SuccessOrExit(error = Tlv::Append<NetworkKeyTlv>(*message, Get<KeyManager>().GetNetworkKey()));
329     SuccessOrExit(error = Tlv::Append<MeshLocalPrefixTlv>(*message, Get<Mle::MleRouter>().GetMeshLocalPrefix()));
330     SuccessOrExit(error = Tlv::Append<ExtendedPanIdTlv>(*message, Get<Mac::Mac>().GetExtendedPanId()));
331 
332     networkName.Init();
333     networkName.SetNetworkName(Get<Mac::Mac>().GetNetworkName().GetAsData());
334     SuccessOrExit(error = networkName.AppendTo(*message));
335 
336     IgnoreError(Get<ActiveDataset>().Read(dataset));
337 
338     if ((tlv = dataset.GetTlv<ActiveTimestampTlv>()) != nullptr)
339     {
340         SuccessOrExit(error = tlv->AppendTo(*message));
341     }
342     else
343     {
344         ActiveTimestampTlv activeTimestamp;
345         activeTimestamp.Init();
346         SuccessOrExit(error = activeTimestamp.AppendTo(*message));
347     }
348 
349     if ((tlv = dataset.GetTlv<ChannelMaskTlv>()) != nullptr)
350     {
351         SuccessOrExit(error = tlv->AppendTo(*message));
352     }
353     else
354     {
355         ChannelMaskBaseTlv channelMask;
356         channelMask.Init();
357         SuccessOrExit(error = channelMask.AppendTo(*message));
358     }
359 
360     if ((tlv = dataset.GetTlv<PskcTlv>()) != nullptr)
361     {
362         SuccessOrExit(error = tlv->AppendTo(*message));
363     }
364     else
365     {
366         PskcTlv pskc;
367         pskc.Init();
368         SuccessOrExit(error = pskc.AppendTo(*message));
369     }
370 
371     if ((tlv = dataset.GetTlv<SecurityPolicyTlv>()) != nullptr)
372     {
373         SuccessOrExit(error = tlv->AppendTo(*message));
374     }
375     else
376     {
377         SecurityPolicyTlv securityPolicy;
378         securityPolicy.Init();
379         SuccessOrExit(error = securityPolicy.AppendTo(*message));
380     }
381 
382     SuccessOrExit(error = Tlv::Append<NetworkKeySequenceTlv>(*message, Get<KeyManager>().GetCurrentKeySequence()));
383 
384 exit:
385     FreeAndNullMessageOnError(message, error);
386     return message;
387 }
388 
HandleJoinerEntrustResponse(void * aContext,otMessage * aMessage,const otMessageInfo * aMessageInfo,Error aResult)389 void JoinerRouter::HandleJoinerEntrustResponse(void *               aContext,
390                                                otMessage *          aMessage,
391                                                const otMessageInfo *aMessageInfo,
392                                                Error                aResult)
393 {
394     static_cast<JoinerRouter *>(aContext)->HandleJoinerEntrustResponse(
395         static_cast<Coap::Message *>(aMessage), static_cast<const Ip6::MessageInfo *>(aMessageInfo), aResult);
396 }
397 
HandleJoinerEntrustResponse(Coap::Message * aMessage,const Ip6::MessageInfo * aMessageInfo,Error aResult)398 void JoinerRouter::HandleJoinerEntrustResponse(Coap::Message *         aMessage,
399                                                const Ip6::MessageInfo *aMessageInfo,
400                                                Error                   aResult)
401 {
402     OT_UNUSED_VARIABLE(aMessageInfo);
403 
404     SendDelayedJoinerEntrust();
405 
406     VerifyOrExit(aResult == kErrorNone && aMessage != nullptr);
407 
408     VerifyOrExit(aMessage->GetCode() == Coap::kCodeChanged);
409 
410     otLogInfoMeshCoP("Receive joiner entrust response");
411     otLogCertMeshCoP("[THCI] direction=recv | type=JOIN_ENT.rsp");
412 
413 exit:
414     return;
415 }
416 
ReadFrom(const Message & aMessage)417 void JoinerRouter::JoinerEntrustMetadata::ReadFrom(const Message &aMessage)
418 {
419     uint16_t length = aMessage.GetLength();
420 
421     OT_ASSERT(length >= sizeof(*this));
422     IgnoreError(aMessage.Read(length - sizeof(*this), *this));
423 }
424 
425 } // namespace MeshCoP
426 } // namespace ot
427 
428 #endif // OPENTHREAD_FTD
429