1# Lint as: python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16# pylint: disable=g-bad-import-order 17 18"""Build and train neural networks.""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import argparse 25import datetime 26import os # pylint: disable=duplicate-code 27from data_load import DataLoader 28 29import numpy as np # pylint: disable=duplicate-code 30import tensorflow as tf 31 32logdir = "logs/scalars/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") 33tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir) 34 35 36def reshape_function(data, label): 37 reshaped_data = tf.reshape(data, [-1, 3, 1]) 38 return reshaped_data, label 39 40 41def calculate_model_size(model): 42 print(model.summary()) 43 var_sizes = [ 44 np.product(list(map(int, v.shape))) * v.dtype.size 45 for v in model.trainable_variables 46 ] 47 print("Model size:", sum(var_sizes) / 1024, "KB") 48 49 50def build_cnn(seq_length): 51 """Builds a convolutional neural network in Keras.""" 52 model = tf.keras.Sequential([ 53 tf.keras.layers.Conv2D( 54 8, (4, 3), 55 padding="same", 56 activation="relu", 57 input_shape=(seq_length, 3, 1)), # output_shape=(batch, 128, 3, 8) 58 tf.keras.layers.MaxPool2D((3, 3)), # (batch, 42, 1, 8) 59 tf.keras.layers.Dropout(0.1), # (batch, 42, 1, 8) 60 tf.keras.layers.Conv2D(16, (4, 1), padding="same", 61 activation="relu"), # (batch, 42, 1, 16) 62 tf.keras.layers.MaxPool2D((3, 1), padding="same"), # (batch, 14, 1, 16) 63 tf.keras.layers.Dropout(0.1), # (batch, 14, 1, 16) 64 tf.keras.layers.Flatten(), # (batch, 224) 65 tf.keras.layers.Dense(16, activation="relu"), # (batch, 16) 66 tf.keras.layers.Dropout(0.1), # (batch, 16) 67 tf.keras.layers.Dense(4, activation="softmax") # (batch, 4) 68 ]) 69 model_path = os.path.join("./netmodels", "CNN") 70 print("Built CNN.") 71 if not os.path.exists(model_path): 72 os.makedirs(model_path) 73 model.load_weights("./netmodels/CNN/weights.h5") 74 return model, model_path 75 76 77def build_lstm(seq_length): 78 """Builds an LSTM in Keras.""" 79 model = tf.keras.Sequential([ 80 tf.keras.layers.Bidirectional( 81 tf.keras.layers.LSTM(22), 82 input_shape=(seq_length, 3)), # output_shape=(batch, 44) 83 tf.keras.layers.Dense(4, activation="sigmoid") # (batch, 4) 84 ]) 85 model_path = os.path.join("./netmodels", "LSTM") 86 print("Built LSTM.") 87 if not os.path.exists(model_path): 88 os.makedirs(model_path) 89 return model, model_path 90 91 92def load_data(train_data_path, valid_data_path, test_data_path, seq_length): 93 data_loader = DataLoader( 94 train_data_path, valid_data_path, test_data_path, seq_length=seq_length) 95 data_loader.format() 96 return data_loader.train_len, data_loader.train_data, data_loader.valid_len, \ 97 data_loader.valid_data, data_loader.test_len, data_loader.test_data 98 99 100def build_net(args, seq_length): 101 if args.model == "CNN": 102 model, model_path = build_cnn(seq_length) 103 elif args.model == "LSTM": 104 model, model_path = build_lstm(seq_length) 105 else: 106 print("Please input correct model name.(CNN LSTM)") 107 return model, model_path 108 109 110def train_net( 111 model, 112 model_path, # pylint: disable=unused-argument 113 train_len, # pylint: disable=unused-argument 114 train_data, 115 valid_len, 116 valid_data, 117 test_len, 118 test_data, 119 kind): 120 """Trains the model.""" 121 calculate_model_size(model) 122 epochs = 50 123 batch_size = 64 124 model.compile( 125 optimizer="adam", 126 loss="sparse_categorical_crossentropy", 127 metrics=["accuracy"]) 128 if kind == "CNN": 129 train_data = train_data.map(reshape_function) 130 test_data = test_data.map(reshape_function) 131 valid_data = valid_data.map(reshape_function) 132 test_labels = np.zeros(test_len) 133 idx = 0 134 for data, label in test_data: # pylint: disable=unused-variable 135 test_labels[idx] = label.numpy() 136 idx += 1 137 train_data = train_data.batch(batch_size).repeat() 138 valid_data = valid_data.batch(batch_size) 139 test_data = test_data.batch(batch_size) 140 model.fit( 141 train_data, 142 epochs=epochs, 143 validation_data=valid_data, 144 steps_per_epoch=1000, 145 validation_steps=int((valid_len - 1) / batch_size + 1), 146 callbacks=[tensorboard_callback]) 147 loss, acc = model.evaluate(test_data) 148 pred = np.argmax(model.predict(test_data), axis=1) 149 confusion = tf.math.confusion_matrix( 150 labels=tf.constant(test_labels), 151 predictions=tf.constant(pred), 152 num_classes=4) 153 print(confusion) 154 print("Loss {}, Accuracy {}".format(loss, acc)) 155 # Convert the model to the TensorFlow Lite format without quantization 156 converter = tf.lite.TFLiteConverter.from_keras_model(model) 157 tflite_model = converter.convert() 158 159 # Save the model to disk 160 open("model.tflite", "wb").write(tflite_model) 161 162 # Convert the model to the TensorFlow Lite format with quantization 163 converter = tf.lite.TFLiteConverter.from_keras_model(model) 164 converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] 165 tflite_model = converter.convert() 166 167 # Save the model to disk 168 open("model_quantized.tflite", "wb").write(tflite_model) 169 170 basic_model_size = os.path.getsize("model.tflite") 171 print("Basic model is %d bytes" % basic_model_size) 172 quantized_model_size = os.path.getsize("model_quantized.tflite") 173 print("Quantized model is %d bytes" % quantized_model_size) 174 difference = basic_model_size - quantized_model_size 175 print("Difference is %d bytes" % difference) 176 177 178if __name__ == "__main__": 179 parser = argparse.ArgumentParser(allow_abbrev=False) 180 parser.add_argument("--model", "-m") 181 parser.add_argument("--person", "-p") 182 args = parser.parse_args() 183 184 seq_length = 128 185 186 print("Start to load data...") 187 if args.person == "true": 188 train_len, train_data, valid_len, valid_data, test_len, test_data = \ 189 load_data("./person_split/train", "./person_split/valid", 190 "./person_split/test", seq_length) 191 else: 192 train_len, train_data, valid_len, valid_data, test_len, test_data = \ 193 load_data("./data/train", "./data/valid", "./data/test", seq_length) 194 195 print("Start to build net...") 196 model, model_path = build_net(args, seq_length) 197 198 print("Start training...") 199 train_net(model, model_path, train_len, train_data, valid_len, valid_data, 200 test_len, test_data, args.model) 201 202 print("Training finished!") 203