1# Lint as: python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16# pylint: disable=g-bad-import-order 17 18"""Test for data_augmentation.py.""" 19 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import unittest 25 26import numpy as np 27 28from data_augmentation import augment_data 29from data_augmentation import time_wrapping 30 31 32class TestAugmentation(unittest.TestCase): 33 34 def test_time_wrapping(self): 35 original_data = np.random.rand(10, 3).tolist() 36 wrapped_data = time_wrapping(4, 5, original_data) 37 self.assertEqual(len(wrapped_data), int(len(original_data) / 4 - 1) * 5) 38 self.assertEqual(len(wrapped_data[0]), len(original_data[0])) 39 40 def test_augment_data(self): 41 original_data = [ 42 np.random.rand(128, 3).tolist(), 43 np.random.rand(66, 2).tolist(), 44 np.random.rand(9, 1).tolist() 45 ] 46 original_label = ["data", "augmentation", "test"] 47 augmented_data, augmented_label = augment_data(original_data, 48 original_label) 49 self.assertEqual(25 * len(original_data), len(augmented_data)) 50 self.assertIsInstance(augmented_data, list) 51 self.assertEqual(25 * len(original_label), len(augmented_label)) 52 self.assertIsInstance(augmented_label, list) 53 for i in range(len(original_label)): # pylint: disable=consider-using-enumerate 54 self.assertEqual(augmented_label[25 * i], original_label[i]) 55 56 57if __name__ == "__main__": 58 unittest.main() 59