/*
 * Copyright (c) 1997 Massachusetts Institute of Technology
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to use, copy, modify, and distribute the Software without
 * restriction, provided the Software, including any modified copies made
 * under this license, is not distributed for a fee, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE MASSACHUSETTS INSTITUTE OF TECHNOLOGY BE LIABLE
 * FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
 * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
 * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * Except as contained in this notice, the name of the Massachusetts
 * Institute of Technology shall not be used in advertising or otherwise
 * to promote the sale, use or other dealings in this Software without
 * prior written authorization from the Massachusetts Institute of
 * Technology.
 *
 */

#include <stdlib.h>
#include <stdio.h>
#include <stdarg.h>
#include <string.h>
#include <math.h>

#include "bench_utils.h"

/* Name of the single FFT to run.  While this string is empty,
   set_fft_name accepts every FFT; otherwise it accepts only the FFT
   whose name compares equal to this string. */
char which_fft[80] = "";

/* When nonzero, dat_printf also echoes its output to stdout. */
short bench_echo_dat = 1;
/* Streams written by log_printf and dat_printf respectively; each
   function writes nothing to its file while the pointer is null. */
FILE *bench_log_file = 0,*bench_dat_file = 0;

/* NOTE(review): the variables below are not referenced in the visible
   part of this file; confirm their meaning against the code that uses
   them. */
int bench_num_transforms = 0, *bench_num_sizes = 0, bench_transform_index = 0;
int bench_t_index_real = 0;
double *bench_norm_avgs = 0, *bench_results = 0;

int check_prime_factors(int n, int maxprime)
/* Returns 1 if the largest prime factor of n is <= maxprime, and 0
   otherwise.  n == 0 is reported as acceptable (1); n == 1 also
   yields 1, and negative n yields 0.  maxprime may not exceed the
   largest entry of the internal prime table (113); a larger value is
   treated as a programming error and terminates the program. */
{
     int small_primes[30] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31,
                              37, 41, 43, 47, 53, 59, 61, 67, 71, 73,
                              79, 83, 89, 97, 101, 103, 107, 109, 113 };
     int num_primes = 30;
     int k = 0;

     if (small_primes[num_primes - 1] < maxprime) {
          printf("\nBUG in check_prime_factors: maxprime is too big (%d)!\n",
                 maxprime);
          exit(1);
     }

     if (n == 0)
          return 1;

     /* Strip every allowed prime out of n; anything left over (> 1)
        is a product of primes larger than maxprime. */
     while (n > 1 && k < num_primes && small_primes[k] <= maxprime) {
          int p = small_primes[k];

          while (n % p == 0)
               n /= p;
          ++k;
     }

     return (n == 1) ? 1 : 0;
}

void do_standard_fft(FFTW_COMPLEX *arr, int rank, int *size, int isign,
                     int reim_alt)
/* Transforms arr in place with the reference FFT routine go_fft,
   calling it once per dimension.

   rank is the number of dimensions and size[0..rank-1] their lengths.
   Every length must have no prime factor larger than 23, and rank may
   not exceed the capacity of the local dimension buffer (20); either
   violation prints a message and terminates the program.

   If reim_alt is nonzero, the real and imaginary parts are taken to
   alternate in memory (imaginary pointer = real pointer + 1) and
   go_fft is passed 2 * isign; otherwise the imaginary parts are
   taken to start N reals after the real parts, where N is the product
   of all the sizes, and go_fft is passed isign unchanged. */
{
     int size_reverse[20]; /* GO expects array in Fortran order */
     int max_rank = (int) (sizeof(size_reverse) / sizeof(size_reverse[0]));
     int i, prod, N = 1;
     int go_fft(FFTW_REAL *a, FFTW_REAL *b,
                int ntot, int n, int nspan, int isn);

     /* size_reverse is a fixed-size buffer: reject ranks that would
        write past its end instead of corrupting the stack. */
     if (rank > max_rank) {
          printf("\nBUG in do_standard_fft: rank is too big (%d)!\n",
                 rank);
          exit(1);
     }

     for (i = 0; i < rank; ++i) {
          if (!check_prime_factors(size[i],23)) {
               printf("\nSorry, the FFT we use to check for correctness\n"
                      "(Singleton) can't handle prime factors > 23.\n");
               exit(1);
          }
          size_reverse[rank - 1 - i] = size[i];
          N *= size[i];
     }

     /* FFT using GO: one pass per dimension, with nspan equal to the
        running product of the (reversed) dimension lengths. */
     prod = 1;
     for (i = 0; i < rank; ++i) {
          prod *= size_reverse[i];
          if (reim_alt) /* alternating real/imag parts */
               go_fft((FFTW_REAL *)arr, ((FFTW_REAL *)arr) + 1,
                      N, size_reverse[i], prod, 2 * isign);
          else /* real/imag parts separate */
               go_fft((FFTW_REAL *)arr, ((FFTW_REAL *)arr) + N,
                      N, size_reverse[i], prod, isign);
     }
}

void bench_print_array(FFTW_COMPLEX *a, int N)
/* Writes the first N complex elements of a to stderr on a single
   line, each as "[index] = (re,im)", followed by a newline. */
{
     int idx = 0;

     while (idx < N) {
          fprintf(stderr,"   [%d] = (%f,%f)", idx, a[idx].re, a[idx].im);
          ++idx;
     }
     fprintf(stderr,"\n");
}

void rbench_print_array(FFTW_REAL *a, int N)
/* Writes the first N real elements of a to stderr on a single line,
   each as "[index] = value", followed by a newline. */
{
     int idx = 0;

     while (idx < N) {
          fprintf(stderr,"   [%d] = %f", idx, a[idx]);
          ++idx;
     }
     fprintf(stderr,"\n");
}

void rbench_init_array(FFTW_REAL *arr, int N)
/* Sets the first N elements of the real array arr to zero.  Does
   nothing when N <= 0. */
{
     FFTW_REAL *dst = arr;
     int remaining;

     for (remaining = N; remaining > 0; --remaining)
          *dst++ = 0.0;
}

void bench_init_array(FFTW_COMPLEX *arr, int N)
/* Zeroes N complex elements of arr by treating it as an array of
   2*N reals. */
{
     FFTW_REAL *as_reals = (FFTW_REAL *) arr;
     int num_reals = 2 * N;

     rbench_init_array(as_reals, num_reals);
}

void rbench_init_array_for_check(FFTW_REAL *arr, int N)
/* Fills the first N elements of arr with a reproducible
   pseudo-random sequence for accuracy checks: the generator is
   re-seeded with srand(3) and element i receives the i-th rand()
   value scaled by 1.0e-9. */
{
     int idx = 0;

     srand(3);
     while (idx < N) {
          arr[idx] = rand() * 1.0e-9;
          ++idx;
     }
}

void bench_init_array_for_check(FFTW_COMPLEX *arr, int N)
/* Fills N complex elements of arr with the reproducible check
   sequence by treating it as an array of 2*N reals. */
{
     FFTW_REAL *as_reals = (FFTW_REAL *) arr;
     int num_reals = 2 * N;

     rbench_init_array_for_check(as_reals, num_reals);
}

double rbench_check_array(FFTW_REAL *arr, int N, double scale)
/* Returns the mean error of array arr of length N.  arr[i] * scale
   is compared against the value that rbench_init_array_for_check
   generates for element i (the generator is re-seeded with srand(3)
   and replayed here).  The per-element error is the symmetric
   relative difference

        2 * |scale*arr[i] - correct|
        / (|scale*arr[i]| + |correct| + 1.0e-10)

   and the result is the sum of those errors divided by N.  Elements
   whose error exceeds 10 * MAX_BENCH_ERROR are reported through
   log_printf, but only while the count of such elements seen so far
   is at most 10. */
{
     double total = 0.0;
     int idx = 0, big_errors_seen = 0;

     srand(3);

     while (idx < N) {
          FFTW_REAL expected;
          double rel_err;

          expected = rand() * 1.0e-9;

          rel_err = fabs(scale*arr[idx] - expected);
          rel_err *= 2.0;
          rel_err /= fabs(scale*arr[idx]) + fabs(expected) + 1.0e-10;
          total += rel_err;

          if (rel_err > MAX_BENCH_ERROR*10) {
               /* The counter advances for every large error, even
                  once reporting has been cut off. */
               int seen_before = big_errors_seen++;

               if (seen_before <= 10)
                    log_printf("\nFound big error at %d: \n"
                               "  %g*%g = %g != %g\n", idx,
                               arr[idx],scale,
                               scale*arr[idx],
                               expected);
          }
          ++idx;
     }

     return (total / N);
}

double bench_check_array(FFTW_COMPLEX *arr, int N, double scale)
/* Returns the mean error of N complex elements of arr, computed by
   rbench_check_array over the same memory viewed as 2*N reals. */
{
     FFTW_REAL *as_reals = (FFTW_REAL *) arr;
     int num_reals = 2 * N;

     return rbench_check_array(as_reals, num_reals, scale);
}

void bench_conjugate_array(FFTW_COMPLEX *arr, int n, short reim_alt)
/* Replaces the n complex numbers stored in arr by their complex
   conjugates.  If reim_alt is nonzero, arr is an ordinary array of
   FFTW_COMPLEX and the imaginary part of each element is negated.
   If reim_alt is zero, arr is viewed as 2*n FFTW_REAL values: the
   first n (real parts) are left alone and the following n (imaginary
   parts) are negated. */
{
     int k;

     if (!reim_alt) {
          /* Split layout: imaginary parts start n reals in. */
          FFTW_REAL *imag = ((FFTW_REAL *)arr) + n;

          for (k = 0; k < n; ++k)
               imag[k] = -imag[k];
          return;
     }

     for (k = 0; k < n; ++k)
          c_im(arr[k]) = -c_im(arr[k]);
}

void bench_copy_array(FFTW_COMPLEX *from_arr, FFTW_COMPLEX *to_arr, int n)
/* Copies the first n elements of from_arr into to_arr, in increasing
   index order.  Nothing is done when both pointers are the same. */
{
     int k;

     if (from_arr == to_arr)
          return;

     for (k = 0; k < n; ++k)
          to_arr[k] = from_arr[k];
}

/* The following procedures operate on global variables (ugh)
   defined in bench_utils.h. */

void compute_normalized_averages(void)
/* Finds the FFT record with the largest cur_mflops and, unless that
   value is -1, logs it and then processes every record whose
   cur_mflops is not -1: its speed relative to the fastest is added to
   its norm_avg accumulator, and the record's current figures
   (including norm_avg / num_sizes) are logged. */
{
     bench_fft_data *node, *fastest = NULL;

     for (node = fft_data_top; node; node = node->next)
          if (!fastest || node->cur_mflops > fastest->cur_mflops)
               fastest = node;

     /* Empty list, or no record holds a value other than -1. */
     if (!(fastest && fastest->cur_mflops != -1))
          return;

     log_printf("Fastest FFT is %s with %g mflops.\n",
                fastest->name, fastest->cur_mflops);

     log_printf("Normalized results:\n");
     for (node = fft_data_top; node; node = node->next) {
          double ratio;

          if (!(node->cur_mflops != -1))
               continue;

          ratio = node->cur_mflops / fastest->cur_mflops;
          node->norm_avg += ratio;
          log_printf(" %s: mflops = %g"
                     " (norm = %g, avg. of %d = %g)\n",
                     node->name, node->cur_mflops, ratio,
                     node->num_sizes, node->norm_avg / node->num_sizes);
     }
}

void output_normalized_averages(void)
/* Emits one "Norm. Avg." row through dat_printf: one comma-separated
   field per FFT record, holding norm_avg / num_sizes, or left empty
   for records whose num_sizes is not positive.  Emits nothing when
   the record list is empty. */
{
     bench_fft_data *node = fft_data_top;

     if (!node)
          return;

     dat_printf("Norm. Avg.");
     for (; node; node = node->next) {
          if (node->num_sizes > 0)
               dat_printf(", %g", node->norm_avg / node->num_sizes);
          else
               dat_printf(", ");
     }
     dat_printf("\n");
}

/* fft_data_top is the head of the singly linked list of FFT records
   built by append_fft_data.  fft_data_cur is the record selected by
   the most recent set_fft_name call; it is null when that call
   rejected the name or after skip_benchmark has run. */
bench_fft_data *fft_data_top = 0, *fft_data_cur = 0;

void destroy_fft_data(void)
{
     bench_fft_data *cur = fft_data_top;

     while (cur) {
	  bench_fft_data *tmp;

	  fftw_free(cur->name);
	  tmp = cur;
	  cur = cur->next;
	  fftw_free(tmp);
     }
     fft_data_top = 0;
}

bench_fft_data *append_fft_data(const char *name)
/* Allocates a new FFT record holding a private copy of name, with
   cur_mflops, norm_avg and num_sizes zeroed, and links it at the end
   of the list headed by fft_data_top.  The record's index is 1 for
   the first record and one more than its predecessor's index
   otherwise.  Returns the new record. */
{
     bench_fft_data *tail = 0, *walker, *node;

     /* Locate the last record, if any. */
     for (walker = fft_data_top; walker; walker = walker->next)
          tail = walker;

     node = (bench_fft_data *) fftw_malloc(sizeof(bench_fft_data));

     node->name = (char *) fftw_malloc(sizeof(char) * (strlen(name) + 1));
     strcpy(node->name, name);

     node->cur_mflops = 0.0;
     node->norm_avg = 0.0;
     node->num_sizes = 0;
     node->next = 0;

     if (!tail) {
          fft_data_top = node;
          node->index = 1;
     }
     else {
          tail->next = node;
          node->index = tail->index + 1;
     }

     return node;
}

bench_fft_data *find_fft_data(const char *name)
/* Returns the first record in the FFT list whose name is equal to
   name, or a null pointer if there is none. */
{
     bench_fft_data *node;

     for (node = fft_data_top; node; node = node->next)
          if (strcmp(node->name, name) == 0)
               return node;

     return 0;
}

int count_ffts(void)
/* Returns the number of records in the FFT list. */
{
     bench_fft_data *node;
     int total = 0;

     for (node = fft_data_top; node; node = node->next)
          total += 1;

     return total;
}

/* Nonzero while the next FFT is enabled; set_fft_name consults it. */
short fft_enable_flag = 1;

void set_fft_enabled(short enabled)
/* Combines enabled into fft_enable_flag with a logical AND: the flag
   ends up 1 only if it was already nonzero and enabled is nonzero,
   and 0 otherwise.  A call with enabled == 0 therefore cannot be
   undone by a later call to this function. */
{
     if (fft_enable_flag && enabled)
          fft_enable_flag = 1;
     else
          fft_enable_flag = 0;
}

/* Nonzero while a skip has been requested for the next FFT, and the
   reason text that goes with the most recent request. */
short fft_skip_flag = 0;
char *fft_skip_message = 0;

void set_fft_skip(short skip, const char *message)
/* Combines skip into fft_skip_flag with a logical OR (the flag ends
   up 1 if it was already nonzero or skip is nonzero, 0 otherwise).
   When skip is nonzero, any previously stored message is freed and
   replaced by a newly allocated copy of message. */
{
     if (fft_skip_flag || skip)
          fft_skip_flag = 1;
     else
          fft_skip_flag = 0;

     if (!skip)
          return;

     if (fft_skip_message)
          fftw_free(fft_skip_message);
     fft_skip_message =
          (char*) fftw_malloc(sizeof(char) * (strlen(message) + 1));
     strcpy(fft_skip_message, message);
}

short set_fft_name(const char *name, int for_computation)
{
     if (fft_enable_flag && (which_fft[0] == 0 || !strcmp(which_fft,name))) {
	  if (!for_computation) {
	       fft_data_cur = append_fft_data(name);

	       dat_printf(", %s",name);
	       log_printf("%d. %s\n", fft_data_cur->index, name);
	  }
	  else {
	       fft_data_cur = find_fft_data(name);
	       
	       if (!fft_data_cur) {
		    log_printf("ERROR!  %s not found!\n", name);
		    exit(EXIT_FAILURE);
	       }

	       log_printf("%d. %s: ", fft_data_cur->index, name);
	  }
	  fft_data_cur->cur_mflops = -1;

	  if (fft_skip_flag)
	       skip_benchmark(fft_skip_message);

	  return 1;
     }
     else {
	  fft_data_cur = 0;
          fft_skip_flag = 0;
	  fft_enable_flag = 1;
	  return 0;
     }
}

void output_mean_error(double mean_error, int for_check_only)
{
     if (!fft_data_cur)
	  return;
     if (for_check_only) {
	  log_printf(" (err=%0.1e)",mean_error);
	  if (mean_error > MAX_BENCH_ERROR) {
	       log_printf("\n\n%s gave wrong answer! (mean err = %g)\n",
			  fft_data_cur->name,mean_error);
	       printf("\n\n%s gave wrong answer! (mean err = %g)\n",
		      fft_data_cur->name,mean_error);
	       exit(1);
	  }
     }
     else {
	  log_printf(" mean fractional error = %e\n", mean_error);
	  dat_printf(", %e",mean_error);
     }
}

void output_results(double t, int iters, int real_N, double mflops_scale)
/* Converts a timing into an "mflops" figure and reports it.  t is
   the time for iters transforms of real_N points; the figure is

        5 * real_N * log2(real_N) * iters / (t * 1e6) * mflops_scale

   The time per transform (t * 1e6 / iters) and the figure are logged,
   the figure is written as a field through dat_printf, and, if an FFT
   is current, it is stored in that record's cur_mflops and the
   record's num_sizes is incremented. */
{
     double flop_count, micro_log2_time, mflops;

     /* Same left-to-right evaluation as a single expression: the
        natural logs are combined as log(N) / log(2) via the divisor. */
     flop_count = 5.0 * real_N * log((double) real_N) * ((double) iters);
     micro_log2_time = log(2.0) * t * 1.0e6;
     mflops = flop_count / micro_log2_time * mflops_scale;

     log_printf("time/fft = %g us, mflops = %g",t * 1.0e6 / iters,
                mflops);

     dat_printf(", %g", mflops);

     if (!fft_data_cur)
          return;

     fft_data_cur->cur_mflops = mflops;
     fft_data_cur->num_sizes += 1;
}

void skip_benchmark(const char *why)
{
     if (fft_data_cur) {
	  fft_data_cur->cur_mflops = -1;
	  fft_skip_flag = 0;
	  dat_printf(", ");
	  log_printf("Skipping %s: %s\n", fft_data_cur->name, why);
	  fft_data_cur = 0;
     }
}

void log_printf(const char *template, ...)
/* printf-style output to bench_log_file, flushed after every call.
   Does nothing while bench_log_file is null. */
{
     va_list args;

     if (!bench_log_file)
          return;

     va_start(args, template);
     vfprintf(bench_log_file, template, args);
     va_end(args);

     fflush(bench_log_file);
}

void dat_printf(const char *template, ...)
/* printf-style output of benchmark data.  The formatted text is
   written to bench_dat_file when that pointer is non-null, and also
   to stdout when bench_echo_dat is nonzero; each stream is flushed
   after it is written.  The argument list is traversed separately
   for each destination. */
{
     va_list args;

     if (bench_dat_file) {
          va_start(args, template);
          vfprintf(bench_dat_file, template, args);
          va_end(args);

          fflush(bench_dat_file);
     }

     if (bench_echo_dat) {
          va_start(args, template);
          vprintf(template, args);
          va_end(args);

          fflush(stdout);
     }
}

