import cmsisdsp as dsp
import cmsisdsp.fixedpoint as f

import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
import scipy.fft

import colorama
from colorama import init, Fore, Back, Style
from numpy.testing import assert_allclose, assert_array_equal

init()


def printTitle(s):
    """Print *s* in bright green, preceded by a blank line."""
    print("\n" + Fore.GREEN + Style.BRIGHT + s + Style.RESET_ALL)


def printSubTitle(s):
    """Print *s* in bright style, preceded by a blank line."""
    print("\n" + Style.BRIGHT + s + Style.RESET_ALL)


def chop(A, eps=1e-6):
    """Return a copy of *A* with entries of magnitude below *eps* set to 0.

    The input array is not modified.
    """
    B = np.copy(A)
    B[np.abs(A) < eps] = 0
    return B


nb = 32
# Named testSignal rather than ``signal`` so it does not shadow the
# ``scipy.signal`` module imported above.
testSignal = (np.cos(2 * np.pi * np.arange(nb) / nb)
              * np.cos(0.2 * 2 * np.pi * np.arange(nb) / nb))

# Floating-point reference: half spectrum (nb//2 + 1 bins) and its inverse.
ref = scipy.fft.rfft(testSignal)
# Pass n explicitly: without it irfft returns 2*(len(ref)-1) samples,
# which only equals nb when nb is even.
invref = scipy.fft.irfft(ref, n=nb)

print(f"ref length = {len(ref)}")
print(ref)

# Convert ref to CMSIS-DSP format: interleaved (real, imag) pairs.
referenceFloat = np.zeros(2 * len(ref))
print(f"referenceFloat length = {len(referenceFloat)}")
referenceFloat[0::2] = np.real(ref)
referenceFloat[1::2] = np.imag(ref)
# Legacy packed layout (kept disabled here): it stores the real part of
# the Nyquist bin in slot 1 so the spectrum has the same length as the
# input. This script instead feeds all nb//2 + 1 bins (nb + 2 reals) to
# the Q31 inverse RFFT, so the Nyquist bin keeps its own slot.
# referenceFloat[1] = np.real(ref[-1])

# Inverse Q31 RFFT (ifftFlag=1, bitReverseFlag=1).
# NOTE(review): the returned init status is not checked.
rifftQ31 = dsp.arm_rfft_instance_q31()
status = dsp.arm_rfft_init_q31(rifftQ31, nb, 1, 1)
# Scale the input by 1/nb; the comparison below is against invref/nb.
referenceQ31 = f.toQ31(referenceFloat / nb)

resultQ31 = dsp.arm_rfft_q31(rifftQ31, referenceQ31)
resultF = f.Q31toF32(resultQ31)

print(f"resultF length = {len(resultF)}")
assert_allclose(invref / nb, resultF, atol=1e-6)

# Forward Q31 RFFT (ifftFlag=0, bitReverseFlag=1).
signalQ31 = f.toQ31(testSignal)
rfftQ31 = dsp.arm_rfft_instance_q31()
status = dsp.arm_rfft_init_q31(rfftQ31, nb, 0, 1)
resultQ31 = dsp.arm_rfft_q31(rfftQ31, signalQ31)
print(len(resultQ31))
print(2 * nb)
# Rescale by nb, mirroring the 1/nb factor applied on the inverse path.
resultF = f.Q31toF32(resultQ31) * nb


def compareWithConjugatePart(r):
    """Check that an interleaved full complex spectrum is conjugate-symmetric.

    *r* holds n complex bins as interleaved (real, imag) pairs, i.e. 2*n
    reals. For the DFT of a real signal, X[k] == conj(X[n - k]) for every
    0 < k < n - k; this checks that relation holds exactly.

    The length is taken from *r* itself (not a global), so the check works
    for any n, including odd n.

    Raises:
        AssertionError: if a bin does not exactly equal the conjugate of its
        mirror bin. numpy.testing is used so the check survives ``python -O``.
    """
    res = r[0::2] + 1j * r[1::2]
    n = len(res)
    # Bins n-1 down to n//2 + 1 (conjugated) should mirror bins
    # 1 .. (n+1)//2 - 1. Both slices have the same length for any n.
    conjPart = res[n - 1:n // 2:-1].conj()
    refPart = res[1:(n + 1) // 2]
    assert_array_equal(refPart, conjPart)


compareWithConjugatePart(resultF)

res = resultF[0::2] + 1j * resultF[1::2]
print(res)

# First nb//2 + 1 bins: the half spectrum comparable to scipy's rfft output.
print(res[0:nb // 2 + 1])
print(res[0:nb // 2 + 1].shape)