#include "fgrepdefs.h"
/*********************  main control ************************/

/*
 * Program entry point.
 *
 * argv[1] names the simulation file.  init_params() reads it and sets
 * `continuing` when the file holds data after the parameter section.  In
 * that case the saved snapshots are replayed by iterate_snapshots();
 * otherwise fresh random weights are created.  Training then runs, and
 * the process exits with status 0.
 *
 * Exits with status 1 and a usage message when no simulation file was
 * given, instead of passing a missing argv[1] on to init_params().
 */
int main(argc,argv)
int argc; char *argv[];
{
  if (argc < 2)
    {
      fprintf(stderr, "usage: %s simulation-file\n",
              (argc > 0 && argv[0]) ? argv[0] : "program");
      exit(1);
    }

  init_params(argv);
  read_inputs();
#ifdef hp
  if (displaying) display_init();
#endif
  if (continuing) iterate_snapshots();
  else init_weights();
  training();
#ifdef hp
  if (displaying) gclose(fildes);
#endif
  exit(0);
}


/*
 * Go through the snapshots saved in the simulation file, loading each one
 * in turn (and, when compiled with `hp` and displaying, showing it on the
 * screen and waiting for the user to press return).
 *
 * On return the weights and word representations hold the last snapshot
 * read, `startepoch` is one past the epoch stored with it, the shuffle
 * table has been advanced by one shuffle() per replayed epoch, and
 * `nextsnapshot` indexes the first entry of snapshots[] that is not
 * below `counter`.
 *
 * Exits with status 1 if the simulation file cannot be opened.
 */
int iterate_snapshots()
{
  register int i;
  int oldepoch, oldtesting, readfun(), randfun();
  float cume;

  fp=fopen(simufile,"r");
  if (fp==NULL)
    {
      perror(simufile);
      exit(1);
    }
  read_params(fp);            /* we have to read in the beg. of file again */
  oldtesting=testing;
  testing=1;		      /* we don't want to change weights */
  oldepoch=0;
  /* Same call as in init_weights(); made here only to advance the random
     number generator (any snapshot read below replaces the values). */
  iterate_weights(randfun, (-1.0), 2.0, 0.0, 1.0);

  /* Stop on the first failed conversion, not only on EOF: fscanf returns 0
     without consuming input on malformed data, which would otherwise spin
     here forever. */
  while (fscanf(fp,"%d",&epoch)==1) /* read the epoch */
    {
      /* update the shuffling and parameters */
      epoch++;
      startepoch=epoch;
      for(i=0; i<epoch-oldepoch; i++) shuffle();
      oldepoch=epoch;
      fscanf(fp, "%f", &cume); /* read the current average error */

      /* read the current word representations and weights; readfun ignores
         the numeric parameters, but iterate_weights takes five arguments,
         so pass all of them explicitly */
      iterate_weights(readfun, 0.0, 0.0, 0.0, 0.0);
#ifdef hp
      if (displaying)
	{
	  init_deltas(0.0, 0);
	  if (epoch>0) display_cume(cume);
	  display_all_weights();
	  sentence(shuffletable[0]);
	  while (getchar()!='\n');  /* wait for return from the user */
	}
#endif
    }
  fclose(fp);
  /* get the next snapshot after where we are now */
  for (nextsnapshot=0; counter > snapshots[nextsnapshot]; nextsnapshot++){}
  testing=oldtesting;
}
      

training()
{
  register int senti;
  for(epoch=startepoch; counter<=phaseends[nphase-1]; epoch++)
    {
      get_current_params();
      init_deltas(0.0, 0);              /* error = 0 for the initial state */
      for(senti=0; senti<nsents; senti++)
	sentence(shuffletable[senti]);
      write_error(stdout);
      shuffle();
      if (counter >= snapshots[nextsnapshot])
	save_current();
    }
#ifdef hp
  if (displaying && epoch>startepoch)
    while (getchar()!='\n'); /* wait for return from the user */
#endif
}


sentence(senti)
register int senti;
{
  register int i,j;
  int *inpdataptr;
   
#if toks
  randomize_tokens();
#endif
  /* get the current input and teaching patterns from the lexicon */
  for (i=0; i<ninpas; i++)
    for(j=0; j<ninprep;j++)
      inprep[i][j]=words[inpnums[senti][i]].rep[j];
  for(i=0;i<noutas;i++)
    for(j=0;j<noutrep;j++)
      tchrep[i][j]=words[tchnums[senti][i]].rep[j];
#ifdef hp
  if (displaying) display_input(senti);
#endif
  propagate_and_display(senti);
  cumulate_error();
}

/*
 * Run the forward pass for the pattern already loaded by sentence() and,
 * unless `testing` is set, the backward pass for sentence `senti`.
 * When compiled with `hp` and `displaying` is set, the activations are
 * drawn after the forward pass, and the weights, input layer and the
 * sentence's non-negative input words after the backward pass.
 */
int propagate_and_display(senti)
register int senti;
{
  register int slot;

  forward_prop();
#ifdef hp
  if (displaying)
    {
      display_assembly(marg, hidy, hidrep, nhidrep);
      display_layer(noutas, outrep, outy, noutrep);
    }
#endif

  if (!testing)
    {
      /* when testing, the weights must not change */
      backward_prop(senti);
#ifdef hp
      if (displaying)
        {
          display_weight_layer(whoy,who,noutas,nhidrep,noutrep);
          display_weight_layer(wihy,wih, ninpas,nhidrep,ninprep);
          display_layer(ninpas,inprep,inpy,ninprep);
          for (slot=0; slot<ninpas; slot++)
            {
              if (inpnums[senti][slot]>-1)
                display_rep(inpnums[senti][slot]);
            }
        }
#endif
    }
}


/*********************  initializations ******************************/

/*
 * Read the simulation parameters from the file named by argv[1] and set
 * up the global start state.
 *
 * `continuing` is set to 1 when the file holds anything after the
 * parameter section (there is at least one more character to read),
 * and to 0 otherwise.
 *
 * Exits with status 1 if argv[1] is missing or the file cannot be opened;
 * previously a failed fopen() was handed straight to read_params().
 */
int init_params(argv)
char *argv[];
{
  int c;

  if (argv[1]==NULL)
    {
      fprintf(stderr, "init_params: no simulation file given\n");
      exit(1);
    }

  /* NOTE(review): unbounded copy; the size of simufile is declared
     elsewhere, so the name length is not checked here. */
  sprintf(simufile, "%s", argv[1]);
  fp=fopen(simufile,"r");
  if (fp==NULL)
    {
      perror(simufile);
      exit(1);
    }
  read_params(fp);
  if((c=getc(fp))==EOF)
    continuing=0;
  else
    continuing=1;
  fclose(fp);

  /* words[-1] is wordarray[0]: the blank word for empty display */
  words = wordarray+1;
  init_deltas(0.0, 1);                  /* error = 0 for the initial state */
  startepoch=0;
  epoch=(-1);
  srand48(seed);			/* start random number sequence */
}

  
/*
 * Read the parameter section at the start of the simulation file `fp`.
 *
 * Each parameter line carries its values first; the remainder of the line
 * is read with fgets() and discarded.  The per-phase values (phase ends
 * and etas) are read as `nphase` numbers each, followed by a single
 * rest-of-line skip.  The snapshot list is read until a value that is not
 * below `snapshotend` has been stored or `maxsnaps` further values have
 * been read.  With `time_params` set, phase ends and non-negative
 * snapshot entries are offset by the current time.
 *
 * Finally the input and output representation sizes are both set to the
 * word representation size, and `nextsnapshot` is reset to 0.
 */
int read_params(fp)
FILE *fp;
{
  char rest[100];               /* scratch for the ignored tail of a line */
  register int n;
  int starttime;

  /* simulation parameters */
  fscanf(fp,"%s", wordfile);
  fgets(rest,99,fp);
  fscanf(fp,"%s", inputfile);
  fgets(rest,99,fp);
  fscanf(fp,"%s", cmapfile);
  fgets(rest,99,fp);
  fscanf(fp,"%s", sb_outdev);
  fgets(rest,99,fp);
  fscanf(fp,"%s", sb_outdriver);
  fgets(rest,99,fp);
  fscanf(fp,"%d %d", &nwordrep, &nhidrep);
  fgets(rest,99,fp);
  fscanf(fp, "%d %d %d %d",&displaying,&testing,&seed,&nphase);
  fgets(rest,99,fp);

  /* one line of phase ends, then one line of learning rates */
  for (n=0; n<nphase; n++)
    {
      fscanf(fp,"%d",&phaseends[n]);
    }
  fgets(rest,99,fp);
  for (n=0; n<nphase; n++)
    {
      fscanf(fp,"%f",&etas[n]);
    }
  fgets(rest,99,fp);

  /* saving info */
  fscanf(fp,"%d", &snapshots[0]);
  n=0;
  /* NOTE(review): the last pass stores into snapshots[maxsnaps]; confirm
     the array has maxsnaps+1 entries. */
  while (n<maxsnaps && snapshots[n]<snapshotend)
    {
      fscanf(fp,"%d", &snapshots[n+1]);
      n++;
    }
  fgets(rest,99,fp);
  nextsnapshot=0;

#if time_params
  /* the limits are relative times: make them absolute */
  starttime=time(0);
  for (n=0; n<nphase; n++)
    {
      phaseends[n] += starttime;
    }
  for (n=0; snapshots[n]<snapshotend; n++)
    {
      if (snapshots[n]>-1)
        snapshots[n] += starttime;
    }
#endif

  noutrep=nwordrep;
  ninprep=noutrep;
}


/*
 * Load the lexicon and the training sentences.
 *
 * The word file is read as whitespace-separated strings into
 * words[i].chars; `nwords` becomes the number read.  The input file
 * starts with the number of input and output assemblies, followed by one
 * sentence per line: `ninpas` input word indices, then `noutas` teaching
 * word indices; the rest of each line is ignored.  `nsents` becomes the
 * number of complete sentences read, the shuffle table is set to the
 * identity order and then shuffled once.
 *
 * Exits with status 1 if either file cannot be opened.  Reading stops at
 * the first sentence whose indices cannot all be read, so a trailing
 * blank line no longer counts as an extra sentence; a warning is printed
 * when that happens before end of file.
 */
int read_inputs()
{
  char s[100];
  register int i,j;
  int c, ok;

  /* read the words */
  fp=fopen(wordfile,"r");
  if (fp==NULL)
    {
      perror(wordfile);
      exit(1);
    }
  for(i=0; fscanf(fp,"%s", words[i].chars)==1; i++);
  nwords=i;
  fclose(fp);

  /* read the number of necessary assemblies */
  fp=fopen(inputfile,"r");
  if (fp==NULL)
    {
      perror(inputfile);
      exit(1);
    }
  fscanf(fp, "%d %d", &ninpas, &noutas); fgets(s,99,fp);

  /* read the input sentences */
  for(i=0; (c=getc(fp))!=EOF; i++)
    {
      ungetc(c, fp);
      ok=1;
      for(j=0; ok && j<ninpas; j++)
	ok = (fscanf(fp,"%d",&inpnums[i][j])==1);
      for(j=0; ok && j<noutas; j++)
	ok = (fscanf(fp,"%d",&tchnums[i][j])==1);
      if (!ok)
	{
	  /* only whitespace left is the normal end of the data */
	  if (!feof(fp))
	    fprintf(stderr, "%s: unreadable sentence %d, rest of file ignored\n",
		    inputfile, i);
	  break;
	}
      fgets(s,99,fp);
      shuffletable[i]=i;
    }
  nsents=i;
  fclose(fp);
  shuffle();
}


/*
 * Fresh start: fill the network with random values via randfun.
 * With the parameters passed here, connection weights are drawn as
 * -1.0 + 2.0*drand48() and word representations as 0.0 + 1.0*drand48().
 * If the first snapshot entry is -1, the initial state is saved.
 */
int init_weights()
{
  int randfun();

  iterate_weights(randfun, (-1.0), 2.0, 0.0, 1.0);

  if (snapshots[0]==(-1))
    {
      /* save the initial state */
      save_current();
    }
#ifdef hp
  if (displaying)
    {
      display_all_weights();
    }
#endif
}


/*
 * Reset the running error accumulators: `deltasum` (summed absolute
 * output error) to `suminit` and `deltanum` (number of terms summed)
 * to `numinit`.
 */
int init_deltas(suminit, numinit)
float suminit;
int numinit;
{
  deltanum=numinit;
  deltasum=suminit;
}
      

/*******************   backprop  ************************************/

/*
 * Forward pass.  Each hidden unit takes the weighted sum of all units of
 * all input assemblies (weights wih) through the sigmoid; each unit of
 * each output assembly takes the weighted sum of the hidden units
 * (weights who) through the sigmoid.  Results are left in hidrep and
 * outrep.
 */
int forward_prop()
{
  register int slot, unit, hid;
  float sigmoid();

  /* input assemblies -> hidden layer */
  for (hid=0; hid<nhidrep; hid++)
    {
      hidrep[hid]=0.0;
      for (slot=0; slot<ninpas; slot++)
        {
          for (unit=0; unit<ninprep; unit++)
            hidrep[hid] += inprep[slot][unit]*wih[slot][unit][hid];
        }
      hidrep[hid]=sigmoid(hidrep[hid]);
    }

  /* hidden layer -> output assemblies */
  for (slot=0; slot<noutas; slot++)
    {
      for (unit=0; unit<noutrep; unit++)
        {
          outrep[slot][unit]=0.0;
          for (hid=0; hid<nhidrep; hid++)
            outrep[slot][unit] += hidrep[hid]*who[slot][unit][hid];
          outrep[slot][unit] = sigmoid(outrep[slot][unit]);
        }
    }
}

/*
 * Backward pass for sentence `senti`: one gradient step of size `eta` on
 * the hidden-to-output weights (who), the input-to-hidden weights (wih)
 * and the input word representations themselves.
 *
 * The order of statements matters: in both weight loops the error signal
 * is propagated through a weight BEFORE that weight is updated, so the
 * lower layer sees the weights that were used in the forward pass.
 *
 * Uses the activations left in hidrep/outrep by forward_prop() and the
 * patterns in inprep/tchrep.
 */
backward_prop(senti)
int senti;
{
  register int i,j,k;
  /* hidsumsig: error signal summed into each hidden unit;
     inpsumsig: error signal summed into each input unit.
     NOTE(review): sized by maxrep/maxinpas but indexed up to
     nhidrep/ninpas/ninprep -- assumes those never exceed the maxima. */
  float clip(), sig, fact, hidsumsig[maxrep], inpsumsig[maxinpas][maxrep];

  /* clear the accumulators */
  for(k=0; k<nhidrep; k++)
    hidsumsig[k]=0.0;
  for(i=0; i<ninpas; i++)
    for(j=0; j<ninprep; j++)
      inpsumsig[i][j]=0.0;

  /* output weights */
  for(i=0; i<noutas; i++)
    for(j=0; j<noutrep; j++)
      {
	/* (target - output) times the sigmoid slope out*(1-out) */
	sig = (tchrep[i][j]-outrep[i][j])*
	  outrep[i][j]*(1.0-outrep[i][j]);
	fact=eta*sig;
	for(k=0; k<nhidrep; k++)
	  {
	    /* propagate through the old weight, then update it */
	    hidsumsig[k] += sig*who[i][j][k];
	    who[i][j][k] += fact*hidrep[k];
	  }
      }

  /* input weights */
  for(k=0; k<nhidrep; k++)
    {
      /* summed signal from above times the sigmoid slope hid*(1-hid) */
      sig = hidsumsig[k]*hidrep[k]*(1.0-hidrep[k]);
      fact = eta*sig;
      for(i=0; i<ninpas; i++)
	for(j=0; j<ninprep; j++)
	  {
	    /* propagate through the old weight, then update it */
	    inpsumsig[i][j] += sig*wih[i][j][k];
	    wih[i][j][k] += fact*inprep[i][j];
	  }
    }
  
  /* representations: move each input word's representation along its
     error signal, clipped by clip(), and store it back both in the input
     layer and in the lexicon; slots with a negative word index are left
     unchanged */
  for(i=0; i<ninpas; i++)
    if (inpnums[senti][i]>-1)
      for(j=0; j<ninprep; j++)
	words[inpnums[senti][i]].rep[j] = inprep[i][j] =
	  clip(inprep[i][j] +eta*inpsumsig[i][j]);
}


/*
 * Add the current pattern's output error to the running totals used for
 * the per-epoch average: for every output unit, `deltanum` is counted up
 * by one and `deltasum` grows by |teaching value - output value|.
 */
int cumulate_error()
{
  register int slot, unit;

  for (slot=0; slot<noutas; slot++)
    {
      for (unit=0; unit<noutrep; unit++)
        {
          deltanum++;
          deltasum += fabs(tchrep[slot][unit]-outrep[slot][unit]);
        }
    }
#ifdef hp
  if (displaying)
    {
      display_cume( deltasum/deltanum);
    }
#endif
}

/*********************** I/O etc functions *****************/ 

/*
 * Apply `dofun` to every adjustable value of the network, always in the
 * same order: all word representation units, then all input-to-hidden
 * weights, then all hidden-to-output weights.
 *
 * dofun is called as dofun(address, a, b), where (a, b) is (par3, par4)
 * for word representations and (par1, par2) for weights.
 */
int iterate_weights(dofun,par1,par2,par3,par4)
int (*dofun)();        /* this function is either rand, read or write */
float par1,par2,par3,par4;
{
  register int a, b, c;

  /* word representations */
  for (a=0; a<nwords; a++)
    {
      for (b=0; b<nwordrep; b++)
        (*dofun)(&words[a].rep[b],par3,par4);
    }

  /* input -> hidden weights */
  for (a=0; a<ninpas; a++)
    {
      for (b=0; b<ninprep; b++)
        {
          for (c=0; c<nhidrep; c++)
            (*dofun)(&wih[a][b][c],par1,par2);
        }
    }

  /* hidden -> output weights */
  for (a=0; a<noutas; a++)
    {
      for (b=0; b<noutrep; b++)
        {
          for (c=0; c<nhidrep; c++)
            (*dofun)(&who[a][b][c],par1,par2);
        }
    }
}

#if toks
/*
 * Use this when cloning the synonymous words: the first two
 * representation units of a fixed set of lexicon entries are redrawn
 * with randfun(..., 0.0, 1.0).  For each unit the words are visited in
 * the order listed in the table below.
 */
int randomize_tokens()
{
  static int cloned[] = { 6, 13, 27, 19, 4, 7, 8 };
  register int i, j, t;
  int ncloned;

  ncloned = (int)(sizeof(cloned)/sizeof(cloned[0]));
  for (j=0; j<2; j++)
    {
      for (t=0; t<ncloned; t++)
        randfun(&words[cloned[t]].rep[j], 0.0, 1.0);
    }

#ifdef hp
  if (displaying)
    {
      for (i=0; i<nwords; i++)
        display_rep(i);
    }
#endif
}
/* use this when cloning all words */
/*randomize_tokens()
{
  register int i,j;

  for (i=4; i<nwords; i++)
    for(j=0; j<2; j++)
      randfun(&words[i].rep[j], 0.0, 1.0);
  
}*/
#endif

/*
 * iterate_weights() callback: read the next floating-point number from
 * the global stream `fp` into *place.  iterate_weights() passes two
 * further arguments, which this function does not declare or use.
 */
int readfun(place)
float *place;
{
  float *target = place;

  fscanf(fp, "%f", target);
}

/*
 * iterate_weights() callback: write *place to the global stream `fp` in
 * "%f" format, one value per line.  iterate_weights() passes two further
 * arguments, which this function does not declare or use.
 */
int writefun(place)
float *place;
{
  float value = *place;

  fprintf(fp, "%f\n", value);
}

/*
 * iterate_weights() callback: store par1 + par2*drand48() in *place,
 * i.e. a random value offset by par1 and scaled by par2.
 */
int randfun(place,par1,par2)
float *place, par1,par2;
{
  double draw = drand48();

  *place = par1+par2*draw;
}


/*
 * Update the simulation parameters that depend on how far the run has
 * progressed.  Starting from the last phase, step back while the previous
 * phase's end is still at or beyond `counter`; the learning rate `eta`
 * is then taken from the phase reached.
 */
int get_current_params()
{
  register int phase;

  phase=nphase-1;
  while (phase>0 && phaseends[phase-1]>=counter)
    {
      phase--;
    }
  eta=etas[phase];
/*  printf("epoch=%d, phase=%d, eta=%f, time=%d\n", epoch,phase,eta,time(0));*/
}

/*
 * Save a snapshot: append the current epoch and average error, followed
 * by every word representation and weight (in iterate_weights() order),
 * to the simulation file, then advance `nextsnapshot`.
 *
 * If the file cannot be opened the snapshot is skipped with a message on
 * stderr instead of writing through a null stream; `nextsnapshot` is
 * still advanced, so a failed snapshot is not retried.
 */
int save_current()
{
  int writefun();

  fp=fopen(simufile,"a");
  if (fp==NULL)
    perror(simufile);
  else
    {
      write_error(fp);
      /* writefun ignores the numeric parameters, but iterate_weights
         takes five arguments, so pass all of them explicitly */
      iterate_weights(writefun, 0.0, 0.0, 0.0, 0.0);
      if (fclose(fp)==EOF)
	perror(simufile);
    }
  nextsnapshot++;
}

      
/*
 * Write one line with the current epoch number and, if any error terms
 * have been accumulated, the average error deltasum/deltanum.  When the
 * stream is stdout the line is prefixed with a label.
 */
int write_error(fp)
FILE *fp;
{
  if (fp==stdout)
    {
      fputs("Epoch, error: ", fp);
    }
  fprintf(fp, "%d ", epoch);
  if (deltanum>0)
    {
      fprintf(fp, " %f", deltasum/deltanum);
    }
  putc('\n', fp);
}


/*
 * Shuffle the presentation order at each epoch so that the system won't
 * learn the order of inputs: performs `nsents` swaps, each between two
 * positions of shuffletable drawn with lrand48().
 */
int shuffle()
{
  register int remaining;
  int first, second, held;

  remaining=nsents;
  while (remaining>0)
    {
      first=lrand48()%nsents;
      second=lrand48()%nsents;

      held=shuffletable[first];
      shuffletable[first]=shuffletable[second];
      shuffletable[second]=held;

      remaining--;
    }
}

/***********************  math stuff **************************/

/* Return the next drand48() value, converted to float. */
float f01rnd()
{
  float draw;

  draw = drand48();
  return draw;
}

/*
 * Limit `activity` to the range 0..1: values below 0 become 0, values
 * above 1 become 1, anything else is returned unchanged.
 */
float clip(activity)
float activity;
{
  if (activity > 1.0)
    return(1.0);
  if (activity < 0.0)
    return(0.0);
  return(activity);
}

float sigmoid(activity)
float activity;
{
  /* transform the activity to a sigmoid response between 0 and 1 */
  return(1.0/(1.0+exp(-activity)));
}

/*
 * Return the larger of the two floats (b when they compare equal).
 * NOTE(review): C99 <math.h> also declares a function named fmax with a
 * different type; this definition conflicts with it wherever that
 * declaration is visible.
 */
float fmax(a,b)
float a, b;
{
  if (a > b)
    return(a);
  return(b);
}

/* Return the larger of the two integers. */
int imax(a,b)
int a, b;
{
  if (a > b)
    return(a);
  return(b);
}
