1# Lint as: python3 2# coding=utf-8 3# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# ============================================================================== 17"""Mix and split data. 18 19Mix different people's data together and randomly split them into train, 20validation and test. These data would be saved separately under "/data". 21It will generate new files with the following structure: 22 23├── data 24│ ├── complete_data 25│ ├── test 26│ ├── train 27│ └── valid 28""" 29 30from __future__ import absolute_import 31from __future__ import division 32from __future__ import print_function 33 34import json 35import random 36from data_prepare import write_data 37 38 39# Read data 40def read_data(path): 41 data = [] 42 with open(path, "r") as f: 43 lines = f.readlines() 44 for idx, line in enumerate(lines): # pylint: disable=unused-variable 45 dic = json.loads(line) 46 data.append(dic) 47 print("data_length:" + str(len(data))) 48 return data 49 50 51def split_data(data, train_ratio, valid_ratio): 52 """Splits data into train, validation and test according to ratio.""" 53 train_data = [] 54 valid_data = [] 55 test_data = [] 56 num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0} 57 for idx, item in enumerate(data): # pylint: disable=unused-variable 58 for i in num_dic: 59 if item["gesture"] == i: 60 num_dic[i] += 1 61 print(num_dic) 62 train_num_dic = {} 63 valid_num_dic = {} 64 for i in num_dic: 65 train_num_dic[i] = int(train_ratio * num_dic[i]) 66 valid_num_dic[i] = int(valid_ratio * num_dic[i]) 67 random.seed(30) 68 random.shuffle(data) 69 for idx, item in enumerate(data): 70 for i in num_dic: 71 if item["gesture"] == i: 72 if train_num_dic[i] > 0: 73 train_data.append(item) 74 train_num_dic[i] -= 1 75 elif valid_num_dic[i] > 0: 76 valid_data.append(item) 77 valid_num_dic[i] -= 1 78 else: 79 test_data.append(item) 80 print("train_length:" + str(len(train_data))) 81 print("test_length:" + str(len(test_data))) 82 return train_data, valid_data, test_data 83 84 85if __name__ == "__main__": 86 data = read_data("./data/complete_data") 87 train_data, valid_data, test_data = split_data(data, 0.6, 0.2) 88 write_data(train_data, "./data/train") 89 write_data(valid_data, "./data/valid") 90 write_data(test_data, "./data/test") 91