1# Lint as: python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#         http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16# pylint: disable=g-bad-import-order
17
18"""Build and train neural networks."""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import argparse
25import datetime
26import os # pylint: disable=duplicate-code
27from data_load import DataLoader
28
29import numpy as np # pylint: disable=duplicate-code
30import tensorflow as tf
31
# Each run of the script logs TensorBoard scalars under its own directory,
# named by the start time at one-second resolution.
logdir = "logs/scalars/{}".format(
    datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
34
35
def reshape_function(data, label):
    """Reshape one sample to ``[-1, 3, 1]`` and pass its label through.

    The trailing singleton axis acts as the channel dimension expected by the
    CNN's ``(seq_length, 3, 1)`` input. Intended for ``tf.data.Dataset.map``.
    """
    return tf.reshape(data, [-1, 3, 1]), label
39
40
def calculate_model_size(model):
    """Print the Keras summary of ``model`` and the size of its trainable weights.

    The size is the sum, over ``model.trainable_variables``, of each
    variable's element count times its per-element byte size
    (``dtype.size``), reported in kilobytes (bytes / 1024).

    Args:
        model: A Keras model, or any object exposing ``summary()`` and an
            iterable ``trainable_variables`` whose items have ``shape`` and
            ``dtype.size``.
    """
    # ``summary()`` prints the table itself and returns None; wrapping it in
    # ``print`` would emit a stray "None" line.
    model.summary()
    total_bytes = 0
    for variable in model.trainable_variables:
        # np.prod replaces np.product (deprecated, removed in NumPy 2.0).
        # int() keeps the arithmetic in plain Python ints; an empty shape
        # (a scalar variable) correctly counts as one element.
        num_elements = int(np.prod([int(dim) for dim in variable.shape]))
        total_bytes += num_elements * variable.dtype.size
    print("Model size:", total_bytes / 1024, "KB")
48
49
def build_cnn(seq_length):
    """Builds a convolutional neural network in Keras.

    Args:
        seq_length: Number of time steps per sample; the model input shape is
            ``(seq_length, 3, 1)``.

    Returns:
        A ``(model, model_path)`` tuple. ``model_path`` is ``./netmodels/CNN``,
        which is created if it does not exist. If
        ``./netmodels/CNN/weights.h5`` exists, its weights are loaded into the
        model.
    """
    # Shape comments below assume seq_length == 128 (the value the script uses).
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
            8, (4, 3),
            padding="same",
            activation="relu",
            input_shape=(seq_length, 3, 1)),  # output_shape=(batch, 128, 3, 8)
        tf.keras.layers.MaxPool2D((3, 3)),  # (batch, 42, 1, 8)
        tf.keras.layers.Dropout(0.1),  # (batch, 42, 1, 8)
        tf.keras.layers.Conv2D(16, (4, 1), padding="same",
                               activation="relu"),  # (batch, 42, 1, 16)
        tf.keras.layers.MaxPool2D((3, 1), padding="same"),  # (batch, 14, 1, 16)
        tf.keras.layers.Dropout(0.1),  # (batch, 14, 1, 16)
        tf.keras.layers.Flatten(),  # (batch, 224)
        tf.keras.layers.Dense(16, activation="relu"),  # (batch, 16)
        tf.keras.layers.Dropout(0.1),  # (batch, 16)
        tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
    ])
    model_path = os.path.join("./netmodels", "CNN")
    print("Built CNN.")
    os.makedirs(model_path, exist_ok=True)
    # Warm-start from previously saved weights only when they exist. Loading
    # unconditionally failed on any run without a saved weights file,
    # including the first run, right after the directory was created.
    weights_path = os.path.join(model_path, "weights.h5")
    if os.path.exists(weights_path):
        model.load_weights(weights_path)
    return model, model_path
75
76
def build_lstm(seq_length):
    """Builds a bidirectional LSTM in Keras.

    Args:
        seq_length: Number of time steps per sample; the model input shape is
            ``(seq_length, 3)``.

    Returns:
        A ``(model, model_path)`` tuple. ``model_path`` is
        ``./netmodels/LSTM``, which is created if it does not exist.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(22),
            input_shape=(seq_length, 3)),  # output_shape=(batch, 44)
        # The 4 classes are mutually exclusive, and train_net compiles with
        # sparse categorical cross-entropy. The head must therefore output a
        # probability distribution: softmax, as in build_cnn, rather than
        # independent per-class sigmoids.
        tf.keras.layers.Dense(4, activation="softmax")  # (batch, 4)
    ])
    model_path = os.path.join("./netmodels", "LSTM")
    print("Built LSTM.")
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    return model, model_path
90
91
def load_data(train_data_path, valid_data_path, test_data_path, seq_length):
    """Load and format the train/valid/test splits with ``DataLoader``.

    Returns:
        A 6-tuple ``(train_len, train_data, valid_len, valid_data, test_len,
        test_data)`` read from the loader after ``format()`` has run.
    """
    loader = DataLoader(train_data_path, valid_data_path, test_data_path,
                        seq_length=seq_length)
    loader.format()
    return (
        loader.train_len, loader.train_data,
        loader.valid_len, loader.valid_data,
        loader.test_len, loader.test_data,
    )
98
99
def build_net(args, seq_length):
    """Build the network selected by ``args.model``.

    Args:
        args: Parsed command-line arguments; ``args.model`` selects the
            architecture and must be "CNN" or "LSTM".
        seq_length: Number of time steps per input sample.

    Returns:
        The ``(model, model_path)`` tuple from ``build_cnn`` or ``build_lstm``.

    Raises:
        ValueError: If ``args.model`` is not a supported model name.
    """
    if args.model == "CNN":
        return build_cnn(seq_length)
    if args.model == "LSTM":
        return build_lstm(seq_length)
    print("Please input correct model name.(CNN    LSTM)")
    # Previously execution fell through to ``return model, model_path`` and
    # crashed with an UnboundLocalError; fail with a clear error instead.
    raise ValueError(
        "Unsupported model name: {!r} (expected CNN or LSTM)".format(args.model))
108
109
def _convert_to_tflite(model, optimizations=None):
    """Convert a Keras model to a serialized TensorFlow Lite model (bytes)."""
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    if optimizations:
        converter.optimizations = optimizations
    return converter.convert()


def _write_bytes(path, payload):
    """Write ``payload`` to ``path``, closing the file deterministically."""
    with open(path, "wb") as f:
        f.write(payload)


def train_net(
        model,
        model_path,  # pylint: disable=unused-argument
        train_len,  # pylint: disable=unused-argument
        train_data,
        valid_len,
        valid_data,
        test_len,
        test_data,
        kind):
    """Trains, evaluates and exports the model.

    The model is compiled with Adam and sparse categorical cross-entropy.
    It is trained for 50 epochs of 1000 steps on batches of 64, and the
    test-set confusion matrix, loss and accuracy are printed. Finally
    ``model.tflite`` (unquantized) and ``model_quantized.tflite``
    (post-training quantized) are written to the current directory, and
    their sizes are printed.

    Args:
        model: Keras model, e.g. from ``build_net``.
        model_path: Unused.
        train_len: Unused; the training set is repeated indefinitely and each
            epoch is bounded by ``steps_per_epoch``.
        train_data: Unbatched ``tf.data`` dataset of ``(sample, label)``
            pairs (batching happens here).
        valid_len: Number of elements in ``valid_data``.
        valid_data: Unbatched validation dataset of ``(sample, label)`` pairs.
        test_len: Number of elements in ``test_data``.
        test_data: Unbatched test dataset of ``(sample, label)`` pairs.
        kind: Model name; for "CNN" every sample is first passed through
            ``reshape_function`` to add a channel axis.
    """
    calculate_model_size(model)
    epochs = 50
    batch_size = 64
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"])
    if kind == "CNN":
        train_data = train_data.map(reshape_function)
        test_data = test_data.map(reshape_function)
        valid_data = valid_data.map(reshape_function)
    # Ground-truth test labels in dataset order, for the confusion matrix.
    test_labels = np.zeros(test_len)
    for idx, (_, label) in enumerate(test_data):
        test_labels[idx] = label.numpy()
    train_data = train_data.batch(batch_size).repeat()
    valid_data = valid_data.batch(batch_size)
    test_data = test_data.batch(batch_size)
    model.fit(
        train_data,
        epochs=epochs,
        validation_data=valid_data,
        steps_per_epoch=1000,
        # Ceiling division: enough batches to cover every validation sample.
        validation_steps=(valid_len + batch_size - 1) // batch_size,
        callbacks=[tensorboard_callback])
    loss, acc = model.evaluate(test_data)
    pred = np.argmax(model.predict(test_data), axis=1)
    confusion = tf.math.confusion_matrix(
        labels=tf.constant(test_labels),
        predictions=tf.constant(pred),
        num_classes=4)
    print(confusion)
    print("Loss {}, Accuracy {}".format(loss, acc))

    # Convert the model to the TensorFlow Lite format without quantization
    # and save it to disk.
    _write_bytes("model.tflite", _convert_to_tflite(model))

    # Convert with post-training quantization and save it to disk.
    # Optimize.DEFAULT supersedes the deprecated OPTIMIZE_FOR_SIZE alias.
    _write_bytes("model_quantized.tflite",
                 _convert_to_tflite(model, [tf.lite.Optimize.DEFAULT]))

    basic_model_size = os.path.getsize("model.tflite")
    print("Basic model is %d bytes" % basic_model_size)
    quantized_model_size = os.path.getsize("model_quantized.tflite")
    print("Quantized model is %d bytes" % quantized_model_size)
    difference = basic_model_size - quantized_model_size
    print("Difference is %d bytes" % difference)
176
177
if __name__ == "__main__":
    parser = argparse.ArgumentParser(allow_abbrev=False)
    # ``choices`` makes argparse reject an unknown model name immediately
    # with a usage message. Otherwise the error surfaces only after all data
    # has been loaded.
    parser.add_argument("--model", "-m", choices=["CNN", "LSTM"],
                        help="Network architecture to train.")
    parser.add_argument("--person", "-p",
                        help='Pass "true" to load data from ./person_split '
                             'instead of ./data.')
    args = parser.parse_args()

    seq_length = 128

    print("Start to load data...")
    data_root = "./person_split" if args.person == "true" else "./data"
    train_len, train_data, valid_len, valid_data, test_len, test_data = \
        load_data(data_root + "/train", data_root + "/valid",
                  data_root + "/test", seq_length)

    print("Start to build net...")
    model, model_path = build_net(args, seq_length)

    print("Start training...")
    train_net(model, model_path, train_len, train_data, valid_len, valid_data,
              test_len, test_data, args.model)

    print("Training finished!")
203