#include <iostream>
#include <vector>
#include <list>
#include <fstream>
#include <algorithm>
#include <cmath>
#include <utility>
#include <complex>
#include <cstddef>
#include <new>
#include <stdexcept>
#include "fftw3.h"
#include <string>
#include <cmath>

#include "fftInterface.h"


using namespace std;
using namespace Eigen;


//Default constructor: no members are initialized here, so the body is intentionally empty.
fftInterface::fftInterface() {}

//*****************************************************************************************
//getFourier(): This function returns the fourier transform of the input Electric Field by calling FFTW library
//****************************************************************************************
void fftInterface::getFourier(const MatrixXcf &Efield, MatrixXcf &output){

int Nr = Efield.rows();
int Nc = Efield.cols();
int N = Nr*Nc;

MatrixXcf EfieldShifted(Nr, Nc);
//phaseShiftField(Efield, -1, wavelength, theta, phi, xGrid, yGrid, EfieldShifted);
EfieldShifted = Efield;

fftw_complex *in, *out;
fftw_plan p;
in = (fftw_complex*) fftw_malloc(sizeof(fftw_complex)*N);
out = (fftw_complex*) fftw_malloc(sizeof(fftw_complex)*N);
p = fftw_plan_dft_2d(Nr, Nc, in, out, FFTW_FORWARD, FFTW_ESTIMATE);

//Write input data to in
for (int r = 0; r < Nr; r++){
   for (int c = 0; c < Nc; c++){
       int offset = r*Nc + c;
       in[offset][0] = real(EfieldShifted(r, c));
       in[offset][1] = imag(EfieldShifted(r, c));
  }
}

fftw_execute(p);

//Write out to output matrix
for (int r = 0; r < Nr; r++){
   for (int c = 0; c < Nc; c++){
      int offset = r*Nc + c;
      complex<float> Fo(out[offset][0], out[offset][1]);
      output(r, c) = Fo;
  }
}

fftw_destroy_plan(p);
fftw_free(in);
fftw_free(out);
return;
}

//*****************************************************************************************
//getDCT(): This function returns the fourier transform of the input Electric Field by calling FFTW library
//****************************************************************************************
void fftInterface::getDCT(const MatrixXcf &Efield, MatrixXf &output){

int Nr = Efield.rows();
int Nc = Efield.cols();
int N = Nr*Nc;

MatrixXcf EfieldShifted(Nr, Nc);
//phaseShiftField(Efield, -1, wavelength, theta, phi, xGrid, yGrid, EfieldShifted);
EfieldShifted = Efield;

double *in, *out;
fftw_plan p;
in = (double*) fftw_malloc(sizeof(double)*N);
out = (double*) fftw_malloc(sizeof(double)*N);
p = fftw_plan_r2r_2d(Nr, Nc, in, out, FFTW_REDFT10, FFTW_REDFT10, FFTW_ESTIMATE);

//Write input data to in
for (int r = 0; r < Nr; r++){
   for (int c = 0; c < Nc; c++){
       int offset = r*Nc + c;
       in[offset] = real(EfieldShifted(r, c));
  }
}

fftw_execute(p);

//Write out to output matrix
for (int r = 0; r < Nr; r++){
   for (int c = 0; c < Nc; c++){
      int offset = r*Nc + c;
      output(r, c) = out[offset];
  }
}

fftw_destroy_plan(p);
fftw_free(in);
fftw_free(out);
return;
}
//*************************************************************************************************
//getInverseFourier(): Computes the normalized inverse 2-D discrete Fourier transform of a
//frequency-domain matrix with FFTW (FFTW_BACKWARD, exponent sign +1) and divides the result
//by Nr*Nc. getInverseFourier() applied to the output of getFourier() therefore reproduces
//the original field up to rounding.
//
//  input  : Nr x Nc spectrum in unshifted layout (zero frequency at (0, 0)), packed
//           row-major (index r*Nc + c). It is only read; the non-const reference is kept
//           to match the existing declaration.
//  output : receives the Nr x Nc field; resized to match input.
//
//The transform runs in double precision; results are narrowed to float on output.
//Throws std::bad_alloc if the FFTW buffers cannot be allocated and std::runtime_error
//if FFTW fails to create a plan.
//*************************************************************************************************
void fftInterface::getInverseFourier(MatrixXcf &input, MatrixXcf &output){

    const int Nr = static_cast<int>(input.rows());
    const int Nc = static_cast<int>(input.cols());

    // Resize first so a throwing resize cannot leak the raw FFTW buffers below. If
    // input and output are the same matrix the sizes already agree, so this is a
    // no-op, and input is fully copied into `in` before output is written.
    output.resize(Nr, Nc);
    if (Nr == 0 || Nc == 0) {
        return; // nothing to transform; avoid asking FFTW to plan a zero-size DFT
    }

    // size_t arithmetic so the element count cannot overflow int.
    const std::size_t N = static_cast<std::size_t>(Nr) * static_cast<std::size_t>(Nc);

    fftw_complex *in  = static_cast<fftw_complex*>(fftw_malloc(sizeof(fftw_complex) * N));
    fftw_complex *out = static_cast<fftw_complex*>(fftw_malloc(sizeof(fftw_complex) * N));
    if (!in || !out) {
        if (in)  fftw_free(in);
        if (out) fftw_free(out);
        throw std::bad_alloc();
    }

    // Plan before filling `in`: FFTW planners may overwrite the arrays (FFTW_ESTIMATE
    // does not, but this order is safe for any planner flag).
    fftw_plan p = fftw_plan_dft_2d(Nr, Nc, in, out, FFTW_BACKWARD, FFTW_ESTIMATE);
    if (!p) {
        fftw_free(in);
        fftw_free(out);
        throw std::runtime_error("getInverseFourier: fftw_plan_dft_2d failed");
    }

    // Nothing between here and the frees below can throw, so the raw buffers and
    // the plan are always released.
    std::size_t k = 0;
    for (int r = 0; r < Nr; ++r) {
        for (int c = 0; c < Nc; ++c, ++k) {
            const std::complex<float> v = input(r, c);
            in[k][0] = v.real();
            in[k][1] = v.imag();
        }
    }

    fftw_execute(p);

    // FFTW's backward transform is unnormalized (it returns N times the true inverse),
    // so divide by N. Dividing in double and then narrowing rounds only once, and it
    // avoids rounding N itself to float (inexact for N > 2^24).
    const double count = static_cast<double>(N);
    k = 0;
    for (int r = 0; r < Nr; ++r) {
        for (int c = 0; c < Nc; ++c, ++k) {
            output(r, c) = std::complex<float>(static_cast<float>(out[k][0] / count),
                                               static_cast<float>(out[k][1] / count));
        }
    }

    fftw_destroy_plan(p);
    fftw_free(in);
    fftw_free(out);
}
