1 /*
2  * Copyright (c) 2023 Nordic Semiconductor ASA
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <stdint.h>
8 #include <string.h>
9 
10 #include <zephyr/kernel.h>
11 
12 #include <zephyr/net_buf.h>
13 #include <zephyr/sys/printk.h>
14 #include <zephyr/logging/log.h>
15 
16 #include <zephyr/bluetooth/ead.h>
17 #include <zephyr/bluetooth/att.h>
18 #include <zephyr/bluetooth/gatt.h>
19 #include <zephyr/bluetooth/addr.h>
20 #include <zephyr/bluetooth/conn.h>
21 #include <zephyr/bluetooth/uuid.h>
22 #include <zephyr/bluetooth/hci.h>
23 #include <zephyr/bluetooth/bluetooth.h>
24 
25 #include "common/bt_str.h"
26 
27 #include "common.h"
28 
29 LOG_MODULE_REGISTER(ead_central_sample, CONFIG_BT_EAD_LOG_LEVEL);
30 
31 static uint8_t *received_data;
32 static size_t received_data_size;
33 static struct key_material keymat;
34 
35 static bt_addr_le_t peer_addr;
36 static struct bt_conn *default_conn;
37 
38 static struct bt_conn_cb central_cb;
39 static struct bt_conn_auth_cb central_auth_cb;
40 
41 static struct k_poll_signal conn_signal;
42 static struct k_poll_signal passkey_enter_signal;
43 static struct k_poll_signal device_found_cb_completed;
44 
45 /* GATT Discover data */
46 static uint8_t gatt_disc_err;
47 static uint16_t gatt_disc_end_handle;
48 static uint16_t gatt_disc_start_handle;
49 static struct k_poll_signal gatt_disc_signal;
50 
51 /* GATT Read data */
52 static uint8_t gatt_read_err;
53 static uint8_t *gatt_read_res;
54 static uint16_t gatt_read_len;
55 static uint16_t gatt_read_handle;
56 static struct k_poll_signal gatt_read_signal;
57 
data_parse_cb(struct bt_data * data,void * user_data)58 static bool data_parse_cb(struct bt_data *data, void *user_data)
59 {
60 	size_t *parsed_data_size = (size_t *)user_data;
61 
62 	if (data->type == BT_DATA_ENCRYPTED_AD_DATA) {
63 		int err;
64 		struct net_buf_simple decrypted_buf;
65 		size_t decrypted_data_size = BT_EAD_DECRYPTED_PAYLOAD_SIZE(data->data_len);
66 		uint8_t decrypted_data[decrypted_data_size];
67 
68 		err = bt_ead_decrypt(keymat.session_key, keymat.iv, data->data, data->data_len,
69 				     decrypted_data);
70 		if (err < 0) {
71 			LOG_ERR("Error during decryption (err %d)", err);
72 		}
73 
74 		net_buf_simple_init_with_data(&decrypted_buf, &decrypted_data[0],
75 					      decrypted_data_size);
76 
77 		bt_data_parse(&decrypted_buf, &data_parse_cb, user_data);
78 	} else {
79 		LOG_INF("len : %u", data->data_len);
80 		LOG_INF("type: 0x%02x", data->type);
81 		LOG_HEXDUMP_INF(data->data, data->data_len, "data:");
82 
83 		/* Copy the data out if we are running in a test context */
84 		if (received_data != NULL) {
85 			if (bt_data_get_len(data, 1) <=
86 			    (received_data_size - (*parsed_data_size))) {
87 				*parsed_data_size +=
88 					bt_data_serialize(data, &received_data[*parsed_data_size]);
89 			}
90 		} else {
91 			*parsed_data_size += bt_data_get_len(data, 1);
92 		}
93 	}
94 
95 	return true;
96 }
97 
device_found(const bt_addr_le_t * addr,int8_t rssi,uint8_t type,struct net_buf_simple * ad)98 static void device_found(const bt_addr_le_t *addr, int8_t rssi, uint8_t type,
99 			 struct net_buf_simple *ad)
100 {
101 	int err;
102 	size_t parsed_data_size;
103 	char addr_str[BT_ADDR_LE_STR_LEN];
104 
105 	if (default_conn) {
106 		return;
107 	}
108 
109 	bt_addr_le_to_str(addr, addr_str, sizeof(addr_str));
110 
111 	/* We are only interested in the previously connected device. */
112 	if (!bt_addr_le_eq(addr, &peer_addr)) {
113 		return;
114 	}
115 
116 	LOG_DBG("Peer found.");
117 
118 	parsed_data_size = 0;
119 	LOG_INF("Received data size: %zu", ad->len);
120 	bt_data_parse(ad, data_parse_cb, &parsed_data_size);
121 
122 	LOG_DBG("All data parsed. (total size: %zu)", parsed_data_size);
123 
124 	err = bt_le_scan_stop();
125 	if (err) {
126 		LOG_DBG("Failed to stop scanner (err %d)", err);
127 		return;
128 	}
129 
130 	k_poll_signal_raise(&device_found_cb_completed, 0);
131 }
132 
connect_device_found(const bt_addr_le_t * addr,int8_t rssi,uint8_t type,struct net_buf_simple * ad)133 static void connect_device_found(const bt_addr_le_t *addr, int8_t rssi, uint8_t type,
134 				 struct net_buf_simple *ad)
135 {
136 	int err;
137 	char addr_str[BT_ADDR_LE_STR_LEN];
138 
139 	if (default_conn) {
140 		return;
141 	}
142 
143 	/* Connect only to devices in close range */
144 	if (rssi < -70) {
145 		return;
146 	}
147 
148 	bt_addr_le_to_str(addr, addr_str, sizeof(addr_str));
149 
150 	LOG_DBG("Device found: %s (RSSI %d)", addr_str, rssi);
151 
152 	err = bt_le_scan_stop();
153 	if (err) {
154 		LOG_DBG("Failed to stop scanner (err %d)", err);
155 		return;
156 	}
157 
158 	err = bt_conn_le_create(addr, BT_CONN_LE_CREATE_CONN, BT_LE_CONN_PARAM_DEFAULT,
159 				&default_conn);
160 	if (err) {
161 		LOG_DBG("Failed to connect to %s (err %d)", addr_str, err);
162 		return;
163 	}
164 
165 	k_poll_signal_raise(&device_found_cb_completed, 0);
166 }
167 
start_scan(bool connect)168 static int start_scan(bool connect)
169 {
170 	int err;
171 
172 	k_poll_signal_reset(&conn_signal);
173 	k_poll_signal_reset(&device_found_cb_completed);
174 
175 	err = bt_le_scan_start(BT_LE_SCAN_PASSIVE, connect ? connect_device_found : device_found);
176 	if (err) {
177 		LOG_DBG("Scanning failed to start (err %d)", err);
178 		return -1;
179 	}
180 
181 	LOG_DBG("Scanning started.");
182 
183 	if (connect) {
184 		LOG_DBG("Waiting for connection");
185 		await_signal(&conn_signal);
186 	}
187 
188 	await_signal(&device_found_cb_completed);
189 
190 	return 0;
191 }
192 
gatt_read_cb(struct bt_conn * conn,uint8_t att_err,struct bt_gatt_read_params * params,const void * data,uint16_t read_len)193 static uint8_t gatt_read_cb(struct bt_conn *conn, uint8_t att_err,
194 			    struct bt_gatt_read_params *params, const void *data, uint16_t read_len)
195 {
196 	gatt_read_err = att_err;
197 	gatt_read_len = read_len;
198 	gatt_read_handle = params->by_uuid.start_handle;
199 
200 	if (!att_err) {
201 		memcpy(gatt_read_res, data, read_len);
202 		k_poll_signal_raise(&gatt_read_signal, 0);
203 	} else {
204 		LOG_ERR("ATT error (err %d)", att_err);
205 	}
206 
207 	return BT_GATT_ITER_STOP;
208 }
209 
gatt_read(struct bt_conn * conn,const struct bt_uuid * uuid,size_t read_size,uint16_t start_handle,uint16_t end_handle,uint8_t * buf)210 static int gatt_read(struct bt_conn *conn, const struct bt_uuid *uuid, size_t read_size,
211 		     uint16_t start_handle, uint16_t end_handle, uint8_t *buf)
212 {
213 	int err;
214 	size_t offset;
215 	uint16_t handle;
216 	struct bt_gatt_read_params params;
217 
218 	gatt_read_res = &buf[0];
219 
220 	params.handle_count = 0;
221 	params.by_uuid.start_handle = start_handle;
222 	params.by_uuid.end_handle = end_handle;
223 	params.by_uuid.uuid = uuid;
224 	params.func = gatt_read_cb;
225 
226 	k_poll_signal_reset(&gatt_read_signal);
227 
228 	err = bt_gatt_read(conn, &params);
229 	if (err) {
230 		LOG_DBG("GATT read failed (err %d)", err);
231 		return -1;
232 	}
233 
234 	await_signal(&gatt_read_signal);
235 
236 	offset = gatt_read_len;
237 	handle = gatt_read_handle;
238 
239 	while (offset < read_size) {
240 		gatt_read_res = &buf[offset];
241 
242 		params.handle_count = 1;
243 		params.single.handle = handle;
244 		params.single.offset = offset;
245 
246 		k_poll_signal_reset(&gatt_read_signal);
247 
248 		err = bt_gatt_read(conn, &params);
249 		if (err) {
250 			LOG_DBG("GATT read failed (err %d)", err);
251 			return -1;
252 		}
253 
254 		await_signal(&gatt_read_signal);
255 
256 		offset += gatt_read_len;
257 	}
258 
259 	return 0;
260 }
261 
gatt_discover_cb(struct bt_conn * conn,const struct bt_gatt_attr * attr,struct bt_gatt_discover_params * params)262 static uint8_t gatt_discover_cb(struct bt_conn *conn, const struct bt_gatt_attr *attr,
263 				struct bt_gatt_discover_params *params)
264 {
265 	gatt_disc_err = attr ? 0 : BT_ATT_ERR_ATTRIBUTE_NOT_FOUND;
266 
267 	if (attr) {
268 		gatt_disc_start_handle = attr->handle;
269 		gatt_disc_end_handle = ((struct bt_gatt_service_val *)attr->user_data)->end_handle;
270 	}
271 
272 	k_poll_signal_raise(&gatt_disc_signal, 0);
273 
274 	return BT_GATT_ITER_STOP;
275 }
276 
gatt_discover_primary_service(struct bt_conn * conn,const struct bt_uuid * service_type,uint16_t * start_handle,uint16_t * end_handle)277 static int gatt_discover_primary_service(struct bt_conn *conn, const struct bt_uuid *service_type,
278 					 uint16_t *start_handle, uint16_t *end_handle)
279 {
280 	int err;
281 	struct bt_gatt_discover_params params;
282 
283 	params.type = BT_GATT_DISCOVER_PRIMARY;
284 	params.start_handle = BT_ATT_FIRST_ATTRIBUTE_HANDLE;
285 	params.end_handle = BT_ATT_LAST_ATTRIBUTE_HANDLE;
286 	params.uuid = service_type;
287 	params.func = gatt_discover_cb;
288 
289 	k_poll_signal_reset(&gatt_disc_signal);
290 
291 	err = bt_gatt_discover(conn, &params);
292 	if (err) {
293 		LOG_DBG("Primary service discover failed (err %d)", err);
294 		return -1;
295 	}
296 
297 	await_signal(&gatt_disc_signal);
298 
299 	*start_handle = gatt_disc_start_handle;
300 	*end_handle = gatt_disc_end_handle;
301 
302 	return gatt_disc_err;
303 }
304 
connected(struct bt_conn * conn,uint8_t conn_err)305 static void connected(struct bt_conn *conn, uint8_t conn_err)
306 {
307 	char addr[BT_ADDR_LE_STR_LEN];
308 
309 	bt_addr_le_to_str(bt_conn_get_dst(conn), addr, sizeof(addr));
310 
311 	if (conn_err) {
312 		LOG_DBG("Failed to connect to %s (err %u)", addr, conn_err);
313 
314 		bt_conn_unref(default_conn);
315 		default_conn = NULL;
316 
317 		(void)start_scan(true);
318 
319 		return;
320 	}
321 
322 	LOG_DBG("Connected to: %s", addr);
323 
324 	k_poll_signal_raise(&conn_signal, 0);
325 }
326 
disconnected(struct bt_conn * conn,uint8_t reason)327 static void disconnected(struct bt_conn *conn, uint8_t reason)
328 {
329 	char addr[BT_ADDR_LE_STR_LEN];
330 
331 	bt_addr_le_to_str(bt_conn_get_dst(conn), addr, sizeof(addr));
332 
333 	LOG_DBG("Disconnected: %s, reason 0x%02x %s", addr, reason, bt_hci_err_to_str(reason));
334 
335 	if (default_conn != conn) {
336 		return;
337 	}
338 
339 	bt_conn_unref(default_conn);
340 	default_conn = NULL;
341 }
342 
/**
 * @brief Security-level change callback; only logs the outcome.
 */
static void security_changed(struct bt_conn *conn, bt_security_t level, enum bt_security_err err)
{
	char peer[BT_ADDR_LE_STR_LEN];

	bt_addr_le_to_str(bt_conn_get_dst(conn), peer, sizeof(peer));

	if (err) {
		LOG_DBG("Security failed: %s level %u err %d %s", peer, level, err,
			bt_security_err_to_str(err));
		return;
	}

	LOG_DBG("Security changed: %s level %u", peer, level);
}
356 
identity_resolved(struct bt_conn * conn,const bt_addr_le_t * rpa,const bt_addr_le_t * identity)357 static void identity_resolved(struct bt_conn *conn, const bt_addr_le_t *rpa,
358 			      const bt_addr_le_t *identity)
359 {
360 	char addr_identity[BT_ADDR_LE_STR_LEN];
361 	char addr_rpa[BT_ADDR_LE_STR_LEN];
362 
363 	bt_addr_le_to_str(identity, addr_identity, sizeof(addr_identity));
364 	bt_addr_le_to_str(rpa, addr_rpa, sizeof(addr_rpa));
365 
366 	LOG_DBG("Identity resolved %s -> %s", addr_rpa, addr_identity);
367 
368 	bt_addr_le_copy(&peer_addr, identity);
369 }
370 
auth_passkey_confirm(struct bt_conn * conn,unsigned int passkey)371 static void auth_passkey_confirm(struct bt_conn *conn, unsigned int passkey)
372 {
373 	char passkey_str[7];
374 	char addr[BT_ADDR_LE_STR_LEN];
375 
376 	bt_addr_le_to_str(bt_conn_get_dst(conn), addr, sizeof(addr));
377 
378 	snprintk(passkey_str, ARRAY_SIZE(passkey_str), "%06u", passkey);
379 
380 	printk("Passkey for %s: %s\n", addr, passkey_str);
381 
382 	k_poll_signal_raise(&passkey_enter_signal, 0);
383 }
384 
auth_passkey_display(struct bt_conn * conn,unsigned int passkey)385 static void auth_passkey_display(struct bt_conn *conn, unsigned int passkey)
386 {
387 	char passkey_str[7];
388 	char addr[BT_ADDR_LE_STR_LEN];
389 
390 	bt_addr_le_to_str(bt_conn_get_dst(conn), addr, sizeof(addr));
391 
392 	snprintk(passkey_str, ARRAY_SIZE(passkey_str), "%06u", passkey);
393 
394 	LOG_DBG("Passkey for %s: %s", addr, passkey_str);
395 }
396 
auth_cancel(struct bt_conn * conn)397 static void auth_cancel(struct bt_conn *conn)
398 {
399 	char addr[BT_ADDR_LE_STR_LEN];
400 
401 	bt_addr_le_to_str(bt_conn_get_dst(conn), addr, sizeof(addr));
402 
403 	LOG_DBG("Pairing cancelled: %s", addr);
404 }
405 
init_bt(void)406 static int init_bt(void)
407 {
408 	int err;
409 
410 	default_conn = NULL;
411 
412 	k_poll_signal_init(&conn_signal);
413 	k_poll_signal_init(&passkey_enter_signal);
414 	k_poll_signal_init(&gatt_disc_signal);
415 	k_poll_signal_init(&gatt_read_signal);
416 	k_poll_signal_init(&device_found_cb_completed);
417 
418 	err = bt_enable(NULL);
419 	if (err) {
420 		LOG_ERR("Bluetooth init failed (err %d)", err);
421 		return -1;
422 	}
423 
424 	LOG_DBG("Bluetooth initialized");
425 
426 	err = bt_unpair(BT_ID_DEFAULT, BT_ADDR_LE_ANY);
427 	if (err) {
428 		LOG_ERR("Unpairing failed (err %d)", err);
429 	}
430 
431 	central_cb.connected = connected;
432 	central_cb.disconnected = disconnected;
433 	central_cb.security_changed = security_changed;
434 	central_cb.identity_resolved = identity_resolved;
435 
436 	bt_conn_cb_register(&central_cb);
437 
438 	central_auth_cb.pairing_confirm = NULL;
439 	central_auth_cb.passkey_confirm = auth_passkey_confirm;
440 	central_auth_cb.passkey_display = auth_passkey_display;
441 	central_auth_cb.passkey_entry = NULL;
442 	central_auth_cb.oob_data_request = NULL;
443 	central_auth_cb.cancel = auth_cancel;
444 
445 	err = bt_conn_auth_cb_register(&central_auth_cb);
446 	if (err) {
447 		return -1;
448 	}
449 
450 	return 0;
451 }
452 
run_central_sample(int get_passkey_confirmation (struct bt_conn * conn),uint8_t * test_received_data,size_t test_received_data_size,struct key_material * test_received_keymat)453 int run_central_sample(int get_passkey_confirmation(struct bt_conn *conn),
454 		       uint8_t *test_received_data, size_t test_received_data_size,
455 		       struct key_material *test_received_keymat)
456 {
457 	int err;
458 	bool connect;
459 	uint16_t end_handle;
460 	uint16_t start_handle;
461 
462 	if (test_received_data != NULL) {
463 		received_data = test_received_data;
464 		received_data_size = test_received_data_size;
465 	}
466 
467 	/* Initialize Bluetooth and callbacks */
468 	err = init_bt();
469 	if (err) {
470 		return -1;
471 	}
472 
473 	/* Start scan and connect to our peripheral */
474 	connect = true;
475 	err = start_scan(connect);
476 	if (err) {
477 		return -2;
478 	}
479 
480 	/* Update connection security level */
481 	err = bt_conn_set_security(default_conn, BT_SECURITY_L4);
482 	if (err) {
483 		LOG_ERR("Failed to set security (err %d)", err);
484 		return -3;
485 	}
486 
487 	await_signal(&passkey_enter_signal);
488 
489 	err = get_passkey_confirmation(default_conn);
490 	if (err) {
491 		LOG_ERR("Security update failed");
492 		return -4;
493 	}
494 
495 	/* Locate the primary service */
496 	err = gatt_discover_primary_service(default_conn, BT_UUID_CUSTOM_SERVICE, &start_handle,
497 					    &end_handle);
498 	if (err) {
499 		LOG_ERR("Service not found (err %d)", err);
500 		return -5;
501 	}
502 
503 	/* Read the Key Material characteristic */
504 	err = gatt_read(default_conn, BT_UUID_GATT_EDKM, sizeof(keymat), start_handle, end_handle,
505 			(uint8_t *)&keymat);
506 	if (err) {
507 		LOG_ERR("GATT read failed (err %d)", err);
508 		return -6;
509 	}
510 
511 	LOG_HEXDUMP_DBG(keymat.session_key, BT_EAD_KEY_SIZE, "Session Key");
512 	LOG_HEXDUMP_DBG(keymat.iv, BT_EAD_IV_SIZE, "IV");
513 
514 	if (test_received_keymat != NULL) {
515 		memcpy(test_received_keymat, &keymat, sizeof(keymat));
516 	}
517 
518 	/* Start a new scan to get and decrypt the Advertising Data */
519 	err = bt_conn_disconnect(default_conn, BT_HCI_ERR_REMOTE_USER_TERM_CONN);
520 	if (err) {
521 		LOG_ERR("Failed to disconnect.");
522 		return -7;
523 	}
524 
525 	connect = false;
526 	err = start_scan(connect);
527 	if (err) {
528 		return -2;
529 	}
530 
531 	return 0;
532 }
533