c
c     fftc1.fx - 1D fast Fourier transform
c
c     Bailey's Six-Step Method, adapted from D. Bailey, 
c     "FFTs in External or Hierarchical Memory", 
c     The Journal of Supercomputing, 4, 23-35 (1990).
c
c     Comments and suggestions to Dave O'Hallaron, droh@cs.cmu.edu

cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc
c     Copyright (c) 1994 by Carnegie Mellon University
c
c     Permission to use, copy, modify, and distribute this software
c     for any purpose and without fee is hereby granted, provided that 
c     the above copyright notice appear in all copies.
cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc
c     Problem-size macros, substituted textually by cpp.  The fft
c     routine runs LOGN radix-2 stages, so N must equal 2**LOGN.
c     NDV2 is parenthesized so that it acts as a single operand
c     wherever it is substituted; unparenthesized, pi/NDV2 would
c     expand to pi/N/2 instead of pi/(N/2).
#define N 512
#define NDV2 (N/2)
#define LOGN 9
#define ITERS 20

      program fftc1
c
c     fftc1 - driver for the six-step 1D FFT of a length N*N vector.
c     Each of ITERS iterations regenerates the input, times the six
c     steps individually (transpose, column FFTs, scale, transpose,
c     column FFTs, transpose), and checks the output.  The clock
c     counts accumulated per step are averaged over ITERS and handed
c     to prperf for reporting.
c
c     array of length N1*N2 is represented by a 2D array of size N1 by N2
      complex a(N,N) 		! Input/Intermediate array
      complex b(N,N)		! Intermediate/Output array

c     precomputed constants
      integer brt(N)          ! bit reverse table for 1d fft's
      complex w(N/2)          ! twiddles for 1d fft's
      complex v(N,N)          ! twiddles for 2d fft (Scale Factor)

c     timing variables
c     NOTE(review): timer_stop is external; presumably it returns the
c     clock ticks elapsed since the matching timer_start (prperf
c     converts ticks to seconds at 2e7 ticks/s) -- confirm against
c     the timer library.
      integer timer_stop
      external timer_stop
      integer t1clocks,t2clocks,t3clocks,s1clocks,f1clocks,f2clocks 

c     misc
      integer k
c     
c     Data mapping (Fx directives): a, b and v are aligned on their
c     second (column) index with template t, which is distributed in
c     blocks of 64.  Then precompute the fft constants; the input
c     itself is regenerated on every iteration by dgen.
c     
      template t(N)
      align a(j,k) with t(k)
      align b(j,k) with t(k)
      align v(j,k) with t(k)
      distribute t(block(64))

c     NOTE(review): fx_sync presumably synchronizes all processors
c     before setup starts -- confirm against the Fx runtime.
      call fx_sync()
      call gen_bit_reverse_table(brt)
      call gen_w_table(w)
      call gen_v_table(v)
c
c     begin the 1d fft program 
c
      t1clocks = 0
      t2clocks = 0
      t3clocks = 0
      s1clocks = 0
      f1clocks = 0
      f2clocks = 0
      do k=1,ITERS

c        input generation and result checking are deliberately left
c        outside the timed regions
         call dgen(a)      			! Generate input Data

c        each step is bracketed by its own timer_start/timer_stop
         call timer_start()
         idxperm(b,2,1)	= a			! step 1: Transpose
         t1clocks = t1clocks + timer_stop()

         call timer_start()
         call cffts(b,brt,w)   			! step 2: Column FFT
         f1clocks = f1clocks + timer_stop()

         call timer_start()
         call scale(b,v)                	! step 3: Scale
         s1clocks = s1clocks + timer_stop()

         call timer_start()
         idxperm(a,2,1) = b			! step 4: Transpose
         t2clocks = t2clocks + timer_stop()

         call timer_start()
         call cffts(a,brt,w)   			! step 5: Column FFT 
         f2clocks = f2clocks + timer_stop()

         call timer_start() 
         idxperm(b,2,1) = a			! step 6: Transpose
         t3clocks = t3clocks + timer_stop()

         call chkmat(b)                 	! check the results

      enddo
c     average clocks per input vector (integer division truncates)
      t1clocks = t1clocks/ITERS
      t2clocks = t2clocks/ITERS
      t3clocks = t3clocks/ITERS
      s1clocks = s1clocks/ITERS
      f1clocks = f1clocks/ITERS
      f2clocks = f2clocks/ITERS
      call prperf(t1clocks,t2clocks,t3clocks,s1clocks,f1clocks,f2clocks) ! Print Results
      end
       
c     
c     dgen - initialize the vector with a point source
c     
      subroutine dgen(a)
      complex a(N,N), cmplx
      intrinsic cmplx
c
c     Build the test input: the whole array is cleared, then a single
c     point source of value (N*N, N*N) is placed in row 1 of column
c     N/2+1.
c
      integer mid
      complex spike

      template t(N)
      align a(j,k) with t(k)
      distribute t(block(64))

      mid = NDV2 + 1
      spike = cmplx(N*N, N*N)

      a = (0.0, 0.0)
      a(1,mid) = spike
      end

c
c     gen_bit_reverse_table - initialize bit reverse table
c       
c     Postcondition: br(i) = bit-reverse(i-1) + 1       
c
      subroutine gen_bit_reverse_table(brt)
      integer brt(N)
c
c     Fill brt with 1-based bit reversals:
c        brt(i) = bit-reverse(i-1) + 1, over log2(N) bits,
c     assuming N is a power of two.
c
c     Uses the recurrence rev(m) = rev(m/2)/2 + mod(m,2)*(N/2).  The
c     reversal of m is the reversal of m/2 shifted down one place,
c     with the low bit of m moved to the top position.  Entry m/2 is
c     always filled before entry m, so one forward pass suffices.
c
      integer idx, half

      brt(1) = 1
      do idx = 1, N-1
         half = (brt(idx/2 + 1) - 1)/2
         brt(idx + 1) = half + mod(idx,2)*(N/2) + 1
      enddo
      return
      end

c
c     gen_w_table: generate powers of w.
c
c           postcondition: w(i) = w**(i-1)            
c
      subroutine gen_w_table(w)
      complex w(NDV2), cmplx
c
c     Fill w with the first N/2 powers of the length-N root of unity:
c        w(i) = exp(-2*pi*sqrt(-1)*(i-1)/N),   i = 1..N/2
c
c     Each power is evaluated directly from its own angle instead of
c     by repeated multiplication, so rounding error does not build up
c     along the table.  The angle is written out as 2*pi*(i-1)/N
c     rather than in terms of the NDV2 macro, so it does not depend
c     on how that macro is parenthesized; an unparenthesized N/2
c     would turn pi/NDV2 into pi/N/2, a quarter of the intended step.
c
      integer i
      real pi, theta
      real real, cos, sin
      intrinsic real, cos, sin, cmplx

      pi = 3.141592653589793

      w(1) = cmplx(1.0,0.0)
      do i = 2, NDV2
         theta = (2.0*pi*real(i-1))/real(N)
         w(i) = cmplx(cos(theta), -sin(theta))
      enddo
      return
      end

c
c     gen_v_table - gen 2d twiddle factors
c
      subroutine gen_v_table(v)
      complex v(N,N), cmplx
c
c     Fill v with the six-step scale factors (2D twiddles):
c        v(j,k) = exp(-2*pi*sqrt(-1)*(j-1)*(k-1)/(N*N))
c
c     Each entry is computed from its own angle.  The earlier form
c     raised a single precomputed root wn to the power (j-1)*(k-1).
c     In single precision cos(2*pi/(N*N)) rounds to exactly 1.0, so
c     |wn| > 1.  Powering to exponents near N*N then magnified that
c     error, together with the rounding of every squaring, into
c     visibly wrong magnitudes and phases.
c
      real pi, theta
      integer k,j
      real cos, sin, real
      intrinsic cos, sin, cmplx, real

      template t(N)
      align v(j,k) with t(k)
      distribute t(block(64))

      pi = 3.141592653589793
      pdo k = 1,N
      pin v(:,k)
      pout v(:,k)
         do j = 1,N
c           (j-1)*(k-1) < N*N, so the angle stays within [0, 2*pi)
            theta = (2.0*pi*real((j-1)*(k-1)))/real(N*N)
            v(j,k) = cmplx(cos(theta),-sin(theta))
         enddo
      endpdo
      end

c
c     scale - multiply each array element by corresponding element of v
c      
      subroutine scale(a,v)
      complex a(N,N),v(N,N)
c
c     Pointwise scaling a(row,col) <- a(row,col)*v(row,col).  This is
c     step 3 of the six-step method, with v from gen_v_table.
c
      integer row, col

      template t(N)
      align a(j,k) with t(k)
      align v(j,k) with t(k)
      distribute t(block(64))

c     columns are independent, so they are handed out by pdo
      pdo col = 1,N
      pin a(:,col), v(:,col)
      pout a(:,col)
         do row = 1,N
            a(row,col) = a(row,col)*v(row,col)
         enddo
      endpdo
      end

c     
c     cffts - perform a 1d fft on each column 
c     
      subroutine cffts(a,brt,W)
      integer brt(N)
      complex a(N,N),w(NDV2)
c
c     Apply the in-place 1D fft independently to every column of a,
c     using the shared bit-reverse table brt and twiddle table w.
c
      integer icol

      template t(N)
      align a(j,k) with t(k)
      distribute t(block(64))

      pdo icol=1,N
      pin a(:,icol)
      pout a(:,icol)
c        pass the column start; fft sees it as a length-N vector
         call fft(a(1,icol),brt,w)
      endpdo
      end
     
c
c     Fast Fourier Transform 
c     1D in-place complex-complex decimation-in-time Cooley-Tukey
c
      subroutine fft(a,brt,W)
c
c     In-place radix-2 decimation-in-time FFT of one length-N complex
c     vector.  The complex actual arguments are viewed through
c     sequence association as real pairs: a(1,i) and a(2,i) are the
c     real and imaginary parts of element i, and likewise for W.
c
c     a   - vector to transform; overwritten with the result
c     brt - brt(i) is the 1-based bit-reversed position of i
c     W   - W(1,p) + sqrt(-1)*W(2,p) = w**(p-1), p = 1..N/2
c
c     The stage loop runs exactly LOGN times, so this assumes
c     N = 2**LOGN.
c
      real a(2,N), W(2,NDV2)
      integer brt(N)

      integer i,j
      integer powerOfW
      integer sPowerOfW
      integer ijDiff
      integer stage
      integer stride
      integer first
      real pwr,pwi, tr, ti
      real ir, ii, jr, ji

c     Fx directive; presumably disables run-time subscript checking
c     in this routine -- confirm against the Fx documentation.
      nocheck

c     
c     bit reverse step
c     swap a(i) with a(brt(i)); the i < j test makes each pair swap
c     exactly once and leaves self-mapped entries alone
c     
      do i = 1,N
         j = brt(i)
         if (i .lt. j) then
            tr = a(1,j)
            ti = a(2,j)
            a(1,j) = a(1,i)
            a(2,j) = a(2,i)
            a(1,i) = tr
            a(2,i) = ti
         endif
      enddo
         
c     
c     butterfly computations
c     Outer loop over stages; within a stage, loop over the distinct
c     twiddle powers, then over every butterfly (i, i+ijDiff) that
c     uses that power.  sPowerOfW is the step through the W table.
c     
      ijDiff = 1
      stride = 2
      sPowerOfW = NDV2
      do stage = 1,LOGN
c     Invariant: stride = 2 ** stage            
c     Invariant: ijDiff = 2 ** (stage - 1)
         
         first = 1
         do powerOfW = 1, NDV2, sPowerOfW
            pwr = W(1,powerOfW)
            pwi = W(2,powerOfW)
            
c     Invariant: pwr + sqrt(-1)*pwi = W**(powerOfW - 1)

            do i = first, N, stride
               j = i + ijDiff
               jr = a(1,j)
               ji = a(2,j)
               ir = a(1,i)
               ii = a(2,i)
c              t = a(j) * twiddle (complex multiply)
               tr = jr*pwr - ji*pwi
               ti = jr*pwi + ji*pwr
c              a(j) <- a(i) - t,  a(i) <- a(i) + t
               a(1,j) = ir - tr
               a(2,j) = ii - ti
               a(1,i) = ir + tr
               a(2,i) = ii + ti
            enddo

c           next twiddle power pairs the next offset in each group
            first = first + 1
         enddo
            
         ijDiff = stride
         stride = stride * 2
         sPowerOfW = sPowerOfW / 2
      enddo

      return
      end
 
c     
c     chkmat - check the output matrix for correctness
c     
      subroutine chkmat(a)
      complex a(N,N)
c
c     Verify the FFT output.  For the point-source input built by
c     dgen, the real and imaginary parts of every entry should equal
c     N*N on odd rows and -N*N on even rows, within tolerance tol.
c     Each out-of-tolerance component counts as one error; a message
c     is printed only when the count is positive.
c
      integer row,col
      integer nerr
      real tol, want, aimag, real
      intrinsic aimag, real

      template t(N)
      align a(j,k) with t(k)
      distribute t(block(64))

      tol = 0.0001
      nerr = 0
c     NOTE(review): nerr is a scalar accumulated inside pdo; confirm
c     that Fx combines its per-processor updates before the print.
      pdo col=1,N
      pin a(:,col)
         do row=1,N
            want = N*N
            if (mod(row,2) .eq. 0) want = -want
            if (real(a(row,col)) .gt. want+tol .or.
     &          real(a(row,col)) .lt. want-tol) nerr = nerr + 1
            if (aimag(a(row,col)) .gt. want+tol .or.
     &          aimag(a(row,col)) .lt. want-tol) nerr = nerr + 1
         enddo
      endpdo
      if (nerr .gt. 0) then
         print *, 'errors = ', nerr
      endif
      end
     
c
c     prperf - print FFT performance
c
      subroutine prperf(t1clks,t2clks,t3clks,s1clks,f1clks,f2clks) 
      integer t1clks,t2clks,t3clks,s1clks,f1clks,f2clks 
c
c     Print per-step and total timings for one input vector.
c
c     Arguments are average clock counts per input vector: t1-t3 for
c     the three transposes, s1 for the scale step, f1-f2 for the two
c     column-FFT steps.  Counts are converted to seconds at clkhz
c     ticks per second.  The rate counts 5*n*log2(n) flops for a
c     length n = N*N transform (log2(n) = 2*LOGN).
c
      real clkhz
      parameter (clkhz = 2.e7)
      real tsecs,ssecs,fsecs,secs,flops,mflops,eff
      integer tpercent,spercent,fpercent

      tsecs = (t1clks+t2clks+t3clks)/clkhz
      ssecs = s1clks/clkhz
      fsecs = (f1clks+f2clks)/clkhz
      secs  = tsecs+ssecs+fsecs

c     real arithmetic so the flop count cannot overflow a default
c     integer when N is enlarged
      flops = 5.0*real(N)*real(N)*real(2*LOGN)

c     all timers may read 0 on a coarse clock; avoid dividing by 0
      if (secs .gt. 0.0) then
         tpercent = (tsecs/secs)*100
         spercent = (ssecs/secs)*100
         fpercent = (fsecs/secs)*100
         mflops = (flops/1000000.0)/secs
         eff = (fsecs/secs)*100
      else
         tpercent = 0
         spercent = 0
         fpercent = 0
         mflops = 0.0
         eff = 0.0
      endif

      print *, '+++++++++++++++++++++++'
      print *, '1D FFT: Points = ', N*N 
      print *, '+++++++++++++++++++++++'
      print *, 'Time per input vector:'
      print *, " "
      print *, 'Step 1: transpose    : ',t1clks/clkhz, 'sec  '
      print *, 'Step 2: fft          : ',f1clks/clkhz, 'sec  '
      print *, 'Step 3: scale        : ',s1clks/clkhz, 'sec  '
      print *, 'Step 4: transpose    : ',t2clks/clkhz, 'sec  '
      print *, 'Step 5: fft          : ',f2clks/clkhz, 'sec  '
      print *, 'Step 6: transpose    : ',t3clks/clkhz, 'sec  '
      print *, " "
      print *, 'total transpose      : ',tsecs, 'sec  ', tpercent, '%'
      print *, 'total scale          : ',ssecs, 'sec  ', spercent, '%'
      print *, 'total ffts           : ',fsecs, 'sec  ', fpercent, '%'
      print *, 'total(s)             : ',secs, 'sec'
      print *, " "
      print *, 'mflop/s              : ',mflops
      print *, 'efficiency           : ',eff
      print *, " "
      end


