1# Lint as: python3 2# coding=utf-8 3# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 4# 5# Licensed under the Apache License, Version 2.0 (the "License"); 6# you may not use this file except in compliance with the License. 7# You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, software 12# distributed under the License is distributed on an "AS IS" BASIS, 13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14# See the License for the specific language governing permissions and 15# limitations under the License. 16# ============================================================================== 17 18"""Split data into train, validation and test dataset according to person. 19 20That is, use some people's data as train, some other people's data as 21validation, and the rest ones' data as test. These data would be saved 22separately under "/person_split". 23 24It will generate new files with the following structure: 25├──person_split 26│ ├── test 27│ ├── train 28│ └──valid 29""" 30 31from __future__ import absolute_import 32from __future__ import division 33from __future__ import print_function 34 35import os 36import random 37from data_split import read_data 38from data_split import write_data 39 40 41def person_split(whole_data, train_names, valid_names, test_names): 42 """Split data by person.""" 43 random.seed(30) 44 random.shuffle(whole_data) 45 train_data = [] 46 valid_data = [] 47 test_data = [] 48 for idx, data in enumerate(whole_data): # pylint: disable=unused-variable 49 if data["name"] in train_names: 50 train_data.append(data) 51 elif data["name"] in valid_names: 52 valid_data.append(data) 53 elif data["name"] in test_names: 54 test_data.append(data) 55 print("train_length:" + str(len(train_data))) 56 print("valid_length:" + str(len(valid_data))) 57 print("test_length:" + str(len(test_data))) 58 return train_data, valid_data, test_data 59 60 61if __name__ == "__main__": 62 data = read_data("./data/complete_data") 63 train_names = [ 64 "hyw", "shiyun", "tangsy", "dengyl", "jiangyh", "xunkai", "negative3", 65 "negative4", "negative5", "negative6" 66 ] 67 valid_names = ["lsj", "pengxl", "negative2", "negative7"] 68 test_names = ["liucx", "zhangxy", "negative1", "negative8"] 69 train_data, valid_data, test_data = person_split(data, train_names, 70 valid_names, test_names) 71 if not os.path.exists("./person_split"): 72 os.makedirs("./person_split") 73 write_data(train_data, "./person_split/train") 74 write_data(valid_data, "./person_split/valid") 75 write_data(test_data, "./person_split/test") 76