1# Lint as: python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#         http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16# pylint: disable=g-bad-import-order
17
18"""Test for data_augmentation.py."""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import unittest
25
26import numpy as np
27
28from data_augmentation import augment_data
29from data_augmentation import time_wrapping
30
31
class TestAugmentation(unittest.TestCase):
    """Length/shape checks for `time_wrapping` and `augment_data`."""

    def test_time_wrapping(self):
        """`time_wrapping(4, 5, data)` returns the expected number of rows.

        The expected-length formula below is the contract this test pins for
        these particular arguments (4, 5); it is not derived here from the
        implementation of `time_wrapping`.
        """
        original_data = np.random.rand(10, 3).tolist()
        wrapped_data = time_wrapping(4, 5, original_data)
        self.assertEqual(len(wrapped_data), int(len(original_data) / 4 - 1) * 5)
        # Each output row must keep the same number of columns as the input rows.
        self.assertEqual(len(wrapped_data[0]), len(original_data[0]))

    def test_augment_data(self):
        """`augment_data` yields 25 samples and 25 labels per input sample.

        The inputs deliberately differ in both row count (128, 66, 9) and
        column count (3, 2, 1) so the augmentation is exercised on
        heterogeneous sequences. The outputs must be plain lists, and the
        label at every 25th position must equal the label of the
        corresponding input sample.
        """
        # Number of outputs `augment_data` is expected to emit per input sample.
        copies_per_sample = 25
        original_data = [
            np.random.rand(128, 3).tolist(),
            np.random.rand(66, 2).tolist(),
            np.random.rand(9, 1).tolist(),
        ]
        original_label = ["data", "augmentation", "test"]
        augmented_data, augmented_label = augment_data(original_data,
                                                       original_label)
        self.assertEqual(copies_per_sample * len(original_data),
                         len(augmented_data))
        self.assertIsInstance(augmented_data, list)
        self.assertEqual(copies_per_sample * len(original_label),
                         len(augmented_label))
        self.assertIsInstance(augmented_label, list)
        # The first entry of each group of `copies_per_sample` outputs must
        # carry the label of the input sample it was generated from.
        for i, label in enumerate(original_label):
            self.assertEqual(augmented_label[copies_per_sample * i], label)
55
56
if __name__ == "__main__":
    # When executed as a script (not imported), hand control to unittest's
    # command-line runner for the tests defined in this module.
    unittest.main(module="__main__")
59