1import cmsisdsp as dsp
2import cmsisdsp.fixedpoint as f
3
4import numpy as np
5from scipy import signal
6import matplotlib.pyplot as plt
7import scipy.fft
8
9import colorama
10from colorama import init,Fore, Back, Style
11from numpy.testing import assert_allclose
12
# Enable colorama before anything prints: the title helpers below emit its
# Fore/Style escape codes. (colorama.init is the same function imported as
# `init` above.)
colorama.init()
14
def printTitle(s):
    """Print *s* as a bright green section title after a blank line.

    The colour/brightness codes are reset afterwards so later output is plain.
    """
    styled = "".join((Fore.GREEN, Style.BRIGHT, s, Style.RESET_ALL))
    print("\n" + styled)
17
def printSubTitle(s):
    """Print *s* in bright (bold) style after a blank line, then reset styling."""
    banner = "".join([Style.BRIGHT, s, Style.RESET_ALL])
    print("\n" + banner)
20
21
def chop(A, eps = 1e-6):
    """Return a copy of *A* in which every entry with magnitude below *eps* is 0.

    The input is never modified. Magnitude is ``np.abs``, so complex entries
    are compared by modulus.
    """
    result = np.array(A, copy=True)
    negligible = np.abs(A) < eps
    result[negligible] = 0
    return result
26
nb = 32
# Test input: one cosine cycle over nb samples, multiplied by a cosine that
# completes 0.2 cycles over the same span.
# NOTE(review): this module-level name shadows `signal` imported from scipy
# above; it is kept unchanged because other code may refer to it.
signal = np.cos(2 * np.pi * np.arange(nb) / nb)*np.cos(0.2*2 * np.pi * np.arange(nb) / nb)

# Floating-point references: the one-sided spectrum (nb//2 + 1 complex bins)
# and its inverse.
ref=scipy.fft.rfft(signal)
invref = scipy.fft.irfft(ref)

print(f"ref length = {len(ref)}")
print(ref)

# Convert ref to CMSIS-DSP format: replace the complex datatype by interleaved
# (re, im) real pairs, giving 2*len(ref) == nb + 2 values.
referenceFloat=np.zeros(2*len(ref))
print(f"referenceFloat length = {len(referenceFloat)}")
referenceFloat[0::2] = np.real(ref)
referenceFloat[1::2] = np.imag(ref)
# The legacy storage trick, `referenceFloat[1] = np.real(ref[-1])`, is NOT
# applied here. That trick copies the real Nyquist value into slot 1 so the
# buffer has the same length as the input. Here the Nyquist bin stays as the
# last (re, im) pair instead.
# NOTE(review): this assumes the installed CMSIS-DSP Q31 inverse RFFT expects
# the nb + 2 layout; confirm against the library version in use.

# Inverse real FFT instance. The init arguments are presumably
# (fftLenReal, ifftFlagR, bitReverseFlag), as in the CMSIS-DSP C API.
rifftQ31=dsp.arm_rfft_instance_q31()
status=dsp.arm_rfft_init_q31(rifftQ31,nb,1,1)
# Any non-zero status (i.e. not ARM_MATH_SUCCESS) means the instance is not
# usable, e.g. an unsupported length. Fail loudly instead of silently
# transforming with a bad instance.
if status != 0:
    raise RuntimeError(f"arm_rfft_init_q31 (inverse, nb={nb}) failed with status {status}")
# Apply CMSIS-DSP scaling
referenceQ31 = f.toQ31(referenceFloat / nb)

resultQ31 = dsp.arm_rfft_q31(rifftQ31,referenceQ31)
resultF = f.Q31toF32(resultQ31)

print(f"resultF length = {len(resultF)}")
# Because the input was pre-divided by nb, the Q31 inverse is compared with
# the float inverse divided by nb.
assert_allclose(invref/nb,resultF,atol=1e-6)

# Forward real FFT of the test signal in Q31.
signalQ31 = f.toQ31(signal)
rfftQ31=dsp.arm_rfft_instance_q31()
status=dsp.arm_rfft_init_q31(rfftQ31,nb,0,1)
if status != 0:
    raise RuntimeError(f"arm_rfft_init_q31 (forward, nb={nb}) failed with status {status}")
resultQ31 = dsp.arm_rfft_q31(rfftQ31,signalQ31)
# Print the actual output length beside 2*nb for comparison.
print(len(resultQ31))
print(2*nb)
# Multiplying by nb presumably undoes the Q31 forward transform's internal
# down-scaling (log2(nb) bits). Confirm against the CMSIS-DSP scaling table.
resultF = f.Q31toF32(resultQ31) * nb
68
def compareWithConjugatePart(r, n=None):
    """Check that an interleaved complex spectrum is conjugate-symmetric.

    A real signal of length ``n`` has a DFT with ``X[k] == conj(X[n - k])``.
    This helper checks that property exactly on bins ``1 .. ceil(n/2) - 1``
    against their mirrored partners.

    Args:
        r: Array-like of length ``2 * n`` holding the spectrum as interleaved
            (re, im) pairs, i.e. ``r[2k]`` and ``r[2k + 1]`` form bin ``k``.
        n: Transform length. Defaults to the module-level ``nb``, which the
            original version always used.

    Raises:
        ValueError: if ``len(r) != 2 * n``.
        AssertionError: if any bin differs from the conjugate of its mirror.
            This is raised via ``np.testing``, so unlike a bare ``assert`` it
            is not stripped under ``python -O``.
    """
    if n is None:
        n = nb
    r = np.asarray(r)
    if len(r) != 2 * n:
        raise ValueError(
            f"expected {2 * n} interleaved values (n={n}), got {len(r)}"
        )
    res = r[0::2] + 1j * r[1::2]
    # Bins 1, 2, ... paired with n-1, n-2, ... . Bin 0 (DC) and, for even n,
    # bin n/2 (Nyquist) are self-paired and not checked.
    refPart = res[1:(n + 1) // 2]
    conjPart = res[n - 1:n // 2:-1].conj()
    np.testing.assert_array_equal(refPart, conjPart)
74
# The forward Q31 spectrum of a real signal must be conjugate-symmetric.
compareWithConjugatePart(resultF)

# Re-assemble the interleaved (re, im) values into complex bins and display
# the full spectrum, then only the non-redundant half (bins 0 .. nb/2).
res = resultF[::2] + resultF[1::2] * 1j
print(res)

print(res[:nb // 2 + 1])
print(res[:nb // 2 + 1].shape)