1# Lint as: python3
2# coding=utf-8
3# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#         http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16# ==============================================================================
17"""Mix and split data.
18
Mix different people's data together and randomly split it into train,
validation and test sets. The splits are saved separately under "./data".
21It will generate new files with the following structure:
22
23├── data
24│   ├── complete_data
25│   ├── test
26│   ├── train
27│   └── valid
28"""
29
30from __future__ import absolute_import
31from __future__ import division
32from __future__ import print_function
33
34import json
35import random
36from data_prepare import write_data
37
38
39# Read data
def read_data(path):
    """Reads a JSON-lines file into a list of parsed records.

    Every non-blank line of the file is decoded with `json.loads` and the
    results are collected in file order. The number of records read is printed
    as "data_length:<n>".

    Args:
        path: Path of the text file to read; one JSON document per line.

    Returns:
        A list with one decoded JSON value per non-blank line.

    Raises:
        ValueError: If a non-blank line is not valid JSON (raised by
            `json.loads`).
    """
    data = []
    with open(path, "r") as f:
        # Iterate the file object directly so lines are streamed instead of
        # being loaded into memory all at once.
        for line in f:
            # A blank line (e.g. a trailing empty line at the end of the file)
            # is not a JSON document; skip it instead of failing to decode.
            if not line.strip():
                continue
            data.append(json.loads(line))
    print("data_length:" + str(len(data)))
    return data
49
50
def split_data(data, train_ratio, valid_ratio):
    """Splits data into train, validation and test according to ratio.

    The split is stratified by the "gesture" field: for each of the labels
    "wing", "ring", "slope" and "negative", `int(train_ratio * count)` items go
    to train, `int(valid_ratio * count)` items go to validation and the
    remaining items of that label go to test. Items are visited in a shuffled
    order produced by a generator seeded with 30, so the result is
    deterministic for a given input.

    Items whose "gesture" is not one of the four labels above are not placed in
    any split.

    The per-label counts, the train length and the test length are printed.

    Args:
        data: Sequence of items, each supporting `item["gesture"]`. It is not
            modified.
        train_ratio: Fraction of each label's items assigned to train.
        valid_ratio: Fraction of each label's items assigned to validation.

    Returns:
        A tuple `(train_data, valid_data, test_data)` of lists.
    """
    train_data = []
    valid_data = []
    test_data = []
    num_dic = {"wing": 0, "ring": 0, "slope": 0, "negative": 0}
    # Count how many items carry each known label.
    for item in data:
        gesture = item["gesture"]
        if gesture in num_dic:
            num_dic[gesture] += 1
    print(num_dic)
    # Remaining per-label quota for train and validation; int() truncates, so
    # any remainder ends up in the test split.
    train_num_dic = {}
    valid_num_dic = {}
    for label, count in num_dic.items():
        train_num_dic[label] = int(train_ratio * count)
        valid_num_dic[label] = int(valid_ratio * count)
    # Shuffle a copy with a private generator: this yields the same permutation
    # as random.seed(30) followed by random.shuffle(), but neither reorders the
    # caller's list nor reseeds the module-level random state.
    shuffled = list(data)
    random.Random(30).shuffle(shuffled)
    for item in shuffled:
        gesture = item["gesture"]
        if gesture not in num_dic:
            # Unknown labels are left out of every split.
            continue
        if train_num_dic[gesture] > 0:
            train_data.append(item)
            train_num_dic[gesture] -= 1
        elif valid_num_dic[gesture] > 0:
            valid_data.append(item)
            valid_num_dic[gesture] -= 1
        else:
            test_data.append(item)
    print("train_length:" + str(len(train_data)))
    print("test_length:" + str(len(test_data)))
    return train_data, valid_data, test_data
83
84
if __name__ == "__main__":
    # Load the combined dataset, split it 60% / 20% / remainder, then write
    # each split to its own file under ./data.
    splits = split_data(read_data("./data/complete_data"), 0.6, 0.2)
    output_paths = ("./data/train", "./data/valid", "./data/test")
    for subset, output_path in zip(splits, output_paths):
        write_data(subset, output_path)
91