1# Lint as: python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#         http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16# pylint: disable=g-bad-import-order
17
18"""Data augmentation that will be used in data_load.py."""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import random
25
26import numpy as np
27
28
def time_wrapping(molecule, denominator, data):
    """Resample `data` so that it plays at (molecule/denominator)x speed.

    The input is walked in strides of `molecule` rows. Every stride yields
    `denominator` output rows; output row `k` of stride `i` blends input rows
    `molecule * i + k` and `molecule * i + k + 1` with weights
    `(denominator - k) / denominator` and `k / denominator` respectively.

    Args:
        molecule: Numerator of the speed ratio (input rows consumed per stride).
        denominator: Denominator of the speed ratio (output rows per stride).
        data: Indexable sequence of rows, each an indexable sequence of numbers.

    Returns:
        A list of `(int(len(data) / molecule) - 1) * denominator` rows, each
        `len(data[0])` values wide. The list is empty when `data` holds fewer
        than `2 * molecule` rows.
    """
    stride_count = int(len(data) / molecule) - 1
    # Pre-fill with zeros; the row width is taken from the first input row.
    warped = [[0] * len(data[0]) for _ in range(stride_count * denominator)]
    for stride in range(stride_count):
        src_base = molecule * stride
        dst_base = denominator * stride
        for col in range(len(data[stride])):
            for step in range(denominator):
                current = data[src_base + step][col]
                following = data[src_base + step + 1][col]
                blended = current * (denominator - step) + following * step
                warped[dst_base + step][col] = blended / denominator
    return warped
41
42
def augment_data(original_data, original_label):
    """Perform data augmentation.

    For every (data, label) pair the output contains 25 samples that all carry
    the same label:
      * the original sample, unchanged (the same object, not a copy);
      * 5 shifted copies, each with one uniform random offset in [-100, 100)
        added to every value;
      * 5 noisy copies, each value perturbed by its own uniform random amount
        in [0, 5);
      * 7 time-warped copies, one per speed ratio in `fractions`;
      * 7 amplitude-scaled copies, one per ratio in `fractions`.

    Args:
        original_data: Iterable of samples; each sample is a rectangular
            sequence of rows of numbers.
        original_label: Iterable of labels, paired with `original_data` by
            position.

    Returns:
        A `(new_data, new_label)` tuple of two equally long lists.
    """
    new_data = []
    new_label = []
    # (molecule, denominator) ratios shared by time warping and amplification.
    fractions = [(3, 2), (5, 3), (2, 3), (3, 4), (9, 5), (6, 5), (4, 5)]
    for data, label in zip(original_data, original_label):
        # Original data
        new_data.append(data)
        new_label.append(label)
        # Sequence shift
        for _ in range(5):
            offset = (random.random() - 0.5) * 200
            new_data.append(
                (np.array(data, dtype=np.float32) + offset).tolist())
            new_label.append(label)
        # Random noise
        for _ in range(5):
            # Build a fresh list on every pass. Reusing one buffer would make
            # all five appended entries alias the same object, so they would
            # all end up holding only the last pass's noise.
            noisy = [[value + 5 * random.random() for value in row]
                     for row in data]
            new_data.append(noisy)
            new_label.append(label)
        # Time warping
        for molecule, denominator in fractions:
            new_data.append(time_wrapping(molecule, denominator, data))
            new_label.append(label)
        # Movement amplification
        for molecule, denominator in fractions:
            new_data.append(
                (np.array(data, dtype=np.float32) * molecule / denominator).tolist())
            new_label.append(label)
    return new_data, new_label
75