/***********************************************************************
  Weighted majority implementation.
  Permission to distribute is granted. (C) Copyright Avrim Blum 1995

  This program does a weighted majority, using one algorithm for each tuple of
  features.  Each algorithm has a truth-table where each entry is a situation
  followed by the results of the last AMT times that situation was seen.
  Each algorithm predicts based on what outcome appears most in the truth-table
  entry.  If it is a new entry, then it predicts a global default.  The global
  default is the majority over the most recent DEFAULT_AMT outcomes seen.

  This code hasn't been cleaned up as much as varval_pairs.c

    Command line arguments allowed:
    first arg is feature set, second is what to predict,
    third is output file and fourth is data file.
    Fifth is when to stop.
    Sixth is 'a' for aggressive, 'l' for lowerbounding weights, or a number
    starting with a digit (e.g. 0.001) to set aggressive and the
    throw-out value.

    If you use command line args, runs in non-interactive mode.

    Note: this code is NOT optimized for most efficient use of "aggressive"
    mode.  In particular, we should store the algorithms in a LIST rather
    than an array so that we can snip them out of the list when they
    are pruned (rather than constantly skipping over them in the array)

*/
#include <ctype.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "header.h"

/* DEFINES  */
#define AMT 5       /* amount of past data stored per slot (the length of
		       state_list.history).  Each algorithm votes over this
		       history with best_choice(), which is defined outside
		       this file; the original note says the vote is a
		       majority biased towards more recent data.    */
#define DEFAULT_AMT 7   /*  amount of data stored to compute default.  Default
			    is computed by majority vote (best_choice() in
			    update_default()) over the past DEFAULT_AMT
			    correct answers.  */
#define beta 0.5       /* this is 1 - the WM multiplier: each mistake
			  multiplies an alg's weight by (1.0 - beta).
			  NOTE: lowercase macro name, so every later use of
			  the identifier "beta" in this file becomes 0.5.  */

#define NPRED 3        /* we print out to user the top NPRED predictions */

#define HASH_SIZE 101  /* number of hash buckets in each state_table */

/* Default weight threshold for the aggressive (throw-out) and lowerbound
   modes.  It expands to an expression over the global num_algs, so it is
   only meaningful after num_algs has been computed. */
#define THROW_OUT_VALUE (0.1/(num_algs*num_algs)) /* for aggressive case */

/* One truth-table entry of an algorithm: a situation (the TUPLE_SIZE
   feature values that algorithm looks at) plus the correct answers that
   followed the most recent times this situation was seen. */
typedef struct state_list_struct  {
  char situation[TUPLE_SIZE][INPUT_LEN]; /* one feature value per string */
  int history[AMT]; /* last AMT correct answers for this situation,
		       most recent first; indices into global_out_array */
  int historysize;  /* number of valid entries in history (1..AMT) */
  struct state_list_struct *next; /* next entry in the same hash bucket */
} state_list;

/* An algorithm's whole truth table: HASH_SIZE chained hash buckets. */
typedef state_list *state_table[HASH_SIZE];

/* globals */
FILE *fileptr, *summaryfp = NULL; /* fileptr: data file of examples;
				     summaryfp: destination of summary lines.
				     Both are set up by init(). */
int num_algs;                 /* how many algs are there?  C(NUM_INPUTS,
				 TUPLE_SIZE), computed in init() */
int findex;                   /* which feature set are we using?
				 (index into the_feature_set) */
int NUM_INPUTS;               /* how many inputs do we actually have?
				 (features in the chosen feature set) */

state_table alg_array[MAX_NUM_ALGS];    /* states (truth tables) of each algorithm */
char alg_name_array[MAX_NUM_ALGS][10];  /* "name" of each alg: its feature
					   indices concatenated (e.g. "03").
					   NOTE(review): 10 bytes may be too
					   small for a large TUPLE_SIZE. */
int alg_inputs_array[MAX_NUM_ALGS][TUPLE_SIZE]; /* tuples for each algorithm:
						   the feature indices it watches */
float throw_out_value;                  /* for AGGRESSIVE case: weight below
					   which an alg is dropped; also the
					   floor used in lowerbound mode */

char global_out_array[maxoutputs][MAX_FEATURES]; /* global array containing
						 names of the output values.
						 Entry 0 is the "******"
						 placeholder set by init(). */
int num_outputs = 0;          /* entries in use in global_out_array */

int default_index;               /* index in g_o_a of SINGLE, GLOBAL default
				    (recomputed by update_default()) */
int mistake_array[MAX_NUM_ALGS], /* per-algorithm mistake counts */
    our_mistakes,                /* total number of mistakes of WM */
    fewest_mistakes;             /* holds # of mistakes of "best" individual*/
float weight_array[MAX_NUM_ALGS];/* weights of all the algorithms */
int ignore_me[MAX_NUM_ALGS];     /* for aggressive case, ignore me? (1 = thrown out) */
int prediction_array[MAX_NUM_ALGS]; /* has individual predictions
				       (indices into global_out_array) */
example_t inputs[MAX_FEATURES];     /* holds list of input features */
example_t correct_ans;              /* holds correct answer */
int correct_index;                  /* index of correct_ans in global_out_array*/
int seen_so_far;                    /* amount of data seen so far */
int top_predictions[NPRED];      /* holds indices of top NPRED predictions;
				    -1 marks "no prediction" */
float top_weights[NPRED];        /* holds corresponding weights */
int aggressive = 0;         /* 1 if throwing out algorithms */
int lowerbound = 0;         /* 1 if put lower bound on weight of algorithm:
			         use same bound as would use to throw out if aggr*/
int num_left;         /* number of algorithms still around (in aggressive case) */
int interactive = 1;     /* interactive?  init() clears it when the
			    configuration comes from the command line */
int OUTPUTTIME  = 1685;  /* when to output: example number at which the
			    summary is written to summaryfp */

/* functions that return values */
state_list *lookup();


int main(int argc, char **argv)
{
  int r; char *p;
  init(argc, argv);   /* initialize. Read through first example name. */

  while ((r=read_current_example(fileptr, the_feature_set[findex].features,
				 the_feature_set[findex].num_features,
				 inputs)) != -1) {
    if (r <-1) {
      if (r==NORESULT) printf("no result in ");
      else printf("couldn't find all features in ");
      printf("example %s.\ngoing on to next example....\n",example_name);
      find_next_desired_event(fileptr);
      continue;
    }
    wm_predict();          /* predict */
    evaluate_result();
    pretty_print();        /* print out info, update number of mistakes */
    wm_update();      /* update algs and weights*/
    update_default(); /* update default value */
    
    find_next_desired_event(fileptr);  /* read up to next add/copy event */
  }
  return 0;
}


/* This runs the WM algorithm.  Currently takes weighted majority of all:
   every live (non-ignored) algorithm predicts, and its prediction receives
   a vote equal to that algorithm's weight.  Could change to, say, only
   look at those algs whose weights are at least 1/2 of the average.

   Side effects:
     prediction_array[i]  - prediction of each live alg i.
     top_predictions[]    - the NPRED outputs with the largest total vote,
                            in decreasing order (lowest index wins ties);
                            -1 when fewer than NPRED outputs exist.
     top_weights[]        - the matching vote totals (left unchanged for
                            -1 slots).
   Returns 0.
 */
int wm_predict(void)
{
  char *alg_input[TUPLE_SIZE];
  int i, j;
  double max_weight;                  /* same type as votes[] so that ties
                                         compare exactly (a float copy
                                         could round below an equal vote) */
  double votes[maxoutputs];           /* votes for each possible answer */

  for (j = 0; j < num_outputs; ++j) votes[j] = 0.0;   /* start at 0 */

  /* run through algorithms and get predictions */
  for (i = 0; i < num_algs; ++i) {
    if (ignore_me[i]) continue;
    for (j = 0; j < TUPLE_SIZE; ++j)
      alg_input[j] = inputs[alg_inputs_array[i][j]];  /* set up alg_input*/
    prediction_array[i] = alg_predict(alg_input, alg_array[i], default_index);
    votes[prediction_array[i]] += weight_array[i];    /* tally votes */
  }

  /* Now, find NPRED predictions: repeatedly take the largest remaining
     vote and mark it with -1.0 (below any real vote, since weights are
     non-negative) so it is not chosen again. */
  for (i = 0; i < NPRED; ++i) {
    max_weight = -1.0;
    for (j = 0; j < num_outputs; ++j) {
      if (votes[j] > max_weight) {
	max_weight = votes[j];
	top_predictions[i] = j;
	top_weights[i] = (float) votes[j];
      }
    }
    if (max_weight == -1.0)             /* didn't get any more */
      top_predictions[i] = -1;
    else
      votes[top_predictions[i]] = -1.0; /* so don't choose it again*/
  }
  return 0;
}


/* Weighted-majority update step, run after the correct answer is known.

   For every live algorithm:
     - a wrong prediction adds one mistake and multiplies its weight by
       (1.0 - beta);
     - in aggressive mode, an alg whose weight falls below throw_out_value
       is thrown out (ignore_me set), as long as more than 4 algs remain;
     - its own truth table learns the correct answer (alg_update()).
   In lowerbound mode, an alg whose weight is already below throw_out_value
   is skipped entirely: its mistake count, weight and truth table all stay
   unchanged.  NOTE(review): skipping the truth-table update may not be
   intended.

   Finally, if every alg updated this round now has more than
   fewest_mistakes mistakes, fewest_mistakes is incremented and each live
   weight is divided by (1.0 - beta).  This keeps the leaders' weights from
   shrinking toward zero.
   Returns 0. */
int wm_update(void)
{
  char *alg_input[TUPLE_SIZE];
  int a, t;
  int lowest = fewest_mistakes + 1;   /* fewest mistakes seen this round */

  for (a = 0; a < num_algs; ++a) {
    if (ignore_me[a])
      continue;
    if (lowerbound && weight_array[a] < throw_out_value)
      continue;

    if (prediction_array[a] != correct_index) {
      mistake_array[a]++;
      weight_array[a] *= (1.0 - beta);
    }

    if (aggressive && num_left > 4 && weight_array[a] < throw_out_value) {
      ignore_me[a] = 1;              /* thrown out from next round on */
      num_left--;
    }

    if (mistake_array[a] < lowest)
      lowest = mistake_array[a];

    /* let the individual algorithm learn this example */
    for (t = 0; t < TUPLE_SIZE; ++t)
      alg_input[t] = inputs[alg_inputs_array[a][t]];
    alg_update(alg_input, alg_array[a]);
  }

  if (lowest <= fewest_mistakes)
    return 0;                        /* someone is still at the minimum */

  /* Everyone updated has erred past the old minimum: rescale. */
  fewest_mistakes++;
  for (a = 0; a < num_algs; ++a)
    if (!ignore_me[a])
      weight_array[a] /= (1.0 - beta);
  return 0;
}


/* compares two lists (arrays) of TUPLE_SIZE strings, 
   and output number of matches */
sitcmp(ar1,ar2)
     char ar1[TUPLE_SIZE][INPUT_LEN], *ar2[TUPLE_SIZE];
{
  int i, eq;
  eq = 0;
  for(i=0; i < TUPLE_SIZE; ++i) {
    if (strcmp(ar1[i],ar2[i]) == SAME) ++eq;
  }
  return( eq );
}

/* copies situation 2 into situation 1. */
sitcpy(sit1,sit2)
        char sit1[TUPLE_SIZE][INPUT_LEN], *sit2[TUPLE_SIZE];
{
  int i;
  for(i=0; i < TUPLE_SIZE; ++i) strcpy(sit1[i],sit2[i]);
}  


/* Prediction of one algorithm for the situation input_data.
   If my_memory has an entry for exactly this situation, the answer is
   best_choice() over that entry's history (best_choice is defined outside
   this file).  Otherwise my_default is returned unchanged.
   The result is an index into global_out_array. */
int alg_predict(char *input_data[TUPLE_SIZE],
	    state_table my_memory, int my_default)
{
  state_list *entry = lookup(my_memory, input_data);

  if (entry == NULL)
    return my_default;          /* never seen: fall back to the default */
  return best_choice(entry->history, entry->historysize);
}


/* Records the global correct_index as the newest outcome for situation
   input_data in my_memory.  An existing entry's history is shifted one slot
   towards the end (the oldest value drops off once AMT are stored).  An
   unseen situation gets a fresh one-element entry via insert().
   Returns 0. */
int alg_update(char *input_data[TUPLE_SIZE], state_table my_memory)
{
  state_list *entry = lookup(my_memory, input_data);
  int slot;

  if (entry == NULL) {
    insert(my_memory, input_data, correct_index);
    return 0;
  }

  if (entry->historysize != AMT)        /* room left: grow by one */
    entry->historysize++;

  slot = entry->historysize;
  while (--slot > 0)                    /* make room at the front */
    entry->history[slot] = entry->history[slot - 1];
  entry->history[0] = correct_index;
  return 0;
}

/* Updates the global default_index from the global correct_index.
   The last DEFAULT_AMT correct answers are kept (most recent first) in a
   function-local buffer, and default_index becomes best_choice() over them
   (best_choice is defined outside this file).  Needs no storage outside
   this function.  Returns 0. */
int update_default(void)
{
  static int recent_answers[DEFAULT_AMT];
  static int recent_count = 0;
  int k;

  if (recent_count != DEFAULT_AMT)
    recent_count++;

  /* shift older answers back one slot, newest goes to the front */
  k = recent_count;
  while (--k > 0)
    recent_answers[k] = recent_answers[k - 1];
  recent_answers[0] = correct_index;

  default_index = best_choice(recent_answers, recent_count);
  return 0;
}



/* this prints out information:
   Top 4 algs, their predictions and total number of mistakes each.
   Our prediction, ID3, and correct answer.
   Our top 3 choices and their weights.
   # mistakes so far of: Us, Default, "no".

   Also, allows to request for info of any algorithm i.
 */
pretty_print(void)
{
  int i,j,k,togo,mistakes, top_four[4];
  char string[100], skip[4], istring[10];
  state_list *listptr;
  static int skipnumber = 0;
  static FILE *outfp = stdout;

  ++seen_so_far;
  if (top_predictions[0] != correct_index) ++our_mistakes;

  if ((seen_so_far % 10 == 0) && aggressive) {
    fprintf(summaryfp, "at example %d, number of predictors still alive: %d\n",
	    seen_so_far, num_left);
  }

  if ((seen_so_far == OUTPUTTIME)) {
    fprintf(summaryfp,
    "name                            number mistakes fraction_correct\n");
    fprintf(summaryfp,"%s\t",example_name);
    fprintf(summaryfp,"%5d\t", seen_so_far);
    fprintf(summaryfp,"%5d \t",our_mistakes);
    fprintf(summaryfp,"%.3f\t\n",1.0 - our_mistakes/((float) seen_so_far));

    for(i=0; i < num_algs; ++i)
      if (!ignore_me[i] && mistake_array[i] == fewest_mistakes) break;
    fprintf(summaryfp,"Top alg was %s: ( ",alg_name_array[i]);
    for(j=0; j < TUPLE_SIZE; ++j) {
      fprintf(summaryfp,"%s ",
	      the_feature_set[findex].features[alg_inputs_array[i][j]]);
    }
    fprintf(summaryfp,")\n\n");
  }

  if (!interactive || skipnumber > 0) {   /* don't print */
    --skipnumber;
    return;
  }

  /* print out data */
  fprintf(outfp,"Example %d: %s\n", seen_so_far, example_name);
  for(i=0; i < the_feature_set[findex].num_features; ++i)
    fprintf(outfp,"%s: %s\n",the_feature_set[findex].features[i],inputs[i]);

  /* print top 4 algs: this is done really slowly and stupidly */
  fprintf(outfp,"top 4:");
  togo = 4;
  mistakes = fewest_mistakes;
  while (togo) {
    for(i = 0; (i < num_algs) && (togo); ++i)
      if (mistake_array[i] == mistakes && !ignore_me[i]) {
	top_four[4 - togo] = i;
	--togo;
      }
    ++mistakes;
  }
  for(i=0; i < 4; ++i) {
    fprintf(outfp,"%7s   ",alg_name_array[top_four[i]]);
    if (i != 3) fprintf(outfp,"        ");
  }
  fprintf(outfp,"\n  ");
  for(i=0; i < 4; ++i) {
    sprintf(string,"(%.7s,%d,%4.2f)",
	    global_out_array[prediction_array[top_four[i]]],
	    mistake_array[top_four[i]], weight_array[top_four[i]]);
    fprintf(outfp,"%18.18s", string);
  }
  fprintf(outfp,"\n");

  /* print out prediction  and correct answer and number of mistakes */
  fprintf(outfp,"Our prediction: %s.    Correct: %s\n",
	 global_out_array[top_predictions[0]], correct_ans);
  fprintf(outfp,"             %d mistakes so far\n", our_mistakes);
  if (aggressive) fprintf(outfp,"    (%d algs left)",num_left);
  fprintf(outfp,"\n");
  /* print top NPRED choices */
  fprintf(outfp," Top %d choices and weights: ",NPRED);
  for(i=0; i < NPRED; ++i)
    if (top_predictions[i] != -1 ) {
      fprintf(outfp,"(%s, %.2f) ",global_out_array[top_predictions[i]],top_weights[i]);
    } else {
      fprintf(outfp,"(none, 0.0)   ");
    }
  fprintf(outfp,"\n");
 

  /* now, get return to continue.  If get an alg name, then display info
     about that algorithm.  If int followed by "s", then treat as skip
     z = zero out, q= quit */
  fprintf(outfp,"\n");
  if (seen_so_far==1) {
    printf("type <cr> for next, <num>s to skip printing of next <num>\n");
    printf("examples, <algname> to see the algorithm, z to zero counts\n");
    printf("and q to quit.\n");
  }
  while(1) {
    gets(string);
    if (string[0] == '\0') return;

    /* On "z", zero out algorithm mistakes.  */
    if (string[0] == 'z') {
      our_mistakes = 0;
      continue;
    }
    if (string[0] == 'q') exit(0);

    if (string[strlen(string)-1] == 's') {
      if (sscanf(string,"%d",&j)) skipnumber = j-1;
      return;
    }
    /* try to compare with an algorithm */
    for(i=0; i < num_algs; ++i) {
      if (strcmp(string, alg_name_array[i]) == SAME) break;
    }
    if (i == num_algs) {
      printf("unknown command/algorithm name.\n");
      continue;
    }
    fprintf(outfp,"Alg. for features %s:    Mistakes %d, weight %5.3f.\n",
	   alg_name_array[i],mistake_array[i],weight_array[i]);
    fprintf(outfp,"Feature set is: ( ");
    for(j=0; j < TUPLE_SIZE; ++j) {
      fprintf(outfp,"%s ", the_feature_set[findex].features[alg_inputs_array[i][j]]);
    }
    fprintf(outfp,")\n");
    for(j=0; j < HASH_SIZE; ++j) {
      for(listptr = alg_array[i][j]; listptr; listptr = listptr->next) {
	fprintf(outfp,"situation: ( ");
	for(k=0;k<TUPLE_SIZE; ++k) fprintf(outfp,"%s ",listptr->situation[k]);
	fprintf(outfp,")        last %d seen: ", listptr->historysize);
	for(k=0; k < listptr->historysize; ++k)
	  fprintf(outfp,"%s ",global_out_array[listptr->history[k]]);
	fprintf(outfp,"\n");
      }
    }
  }
}



/*  Finds out data file and initializes things */
init(int argc, char **argv)
{
  int i,j;
  char filename[100], *ptr, junk[100];
  int tuple[TUPLE_SIZE];

  strcpy(global_out_array[0],"******");
  num_outputs = 1;
  default_index = 0;

  if (argc == 2) { /* print out help */
    printf("\
  First arg is feature set: 0=loc,1=dur,2=start,3=d.o.w,4=big\n\
  Second is what to predict, Third is output file, Fourth is data file\n\
  Fifth is when to stop\n\
  Sixth (if it exists) is 'a' for aggressive, 'l' for lowerbounding weights.\n");
    exit(0);
  }

  if (argc >=3) {
    sscanf(argv[1],"%d",&findex);
    sscanf(argv[2],"%s",to_predict);
    interactive = 0;  /* non-interactive mode */
  } else {
    /* find which feature set to use */
    printf("avaliable feature sets: ");
    for(i=0; i < num_feature_sets; ++i) 
      printf("%s (%d), ", the_feature_set[i].name, i);
    printf("\nenter desired feature set number: ");
    scanf("%d",&findex);
    printf("what feature do you want to predict (type the name)? ");
    scanf("%s",to_predict);
    gets(junk);  /* get newline */
  }
  NUM_INPUTS = the_feature_set[findex].num_features;
  for(num_algs=1,i=0; i < TUPLE_SIZE; ++i)
    num_algs *= (NUM_INPUTS-i);
  for(i=0; i < TUPLE_SIZE; ++i) num_algs = num_algs/(i+1);
  printf("%d algs.\n",num_algs);

  if (argc >= 4)
    strcpy(filename, argv[3]);
  else {
    printf("output file for summary (<cr> for none): ");
    gets(filename);
  }
  if (*filename) {
    if ((summaryfp = fopen(filename,"a")) == NULL)
      printf("can't open file. Not creating summary.\n");
  } else {
    summaryfp = stdout;
  }
  fprintf(summaryfp,"predicting %s with %s\n",to_predict,
	  the_feature_set[findex].name);
  fprintf(summaryfp,"beta is: %g.\n",beta);
  if (argc >= 5) 
    strcpy(filename, argv[4]);
  else {
    printf("Give data file: ");
    gets(filename);
  }
  if (*filename == '\0' || (fileptr = fopen(filename,"r")) == NULL) {
    printf("can't open file '%s'.\n", filename);
    exit(1);
  }
  /* initialize global vars */
  our_mistakes = 0; fewest_mistakes = 0;
  for(i=0; i < num_algs; ++i) {
    mistake_array[i] = 0;
    weight_array[i] = 1.0;
    ignore_me[i] = 0;
  }
  seen_so_far = 0;

  /* set up algorithm names and inputs for each */
  for (j=0; j < TUPLE_SIZE; ++j) tuple[j] = j;
  for(i=0; i < num_algs; ++i) {
    ptr = alg_name_array[i];
    for(j=0; j < TUPLE_SIZE; ++j) {
      alg_inputs_array[i][j] = tuple[j];
      sprintf(ptr,"%d",tuple[j]);
      ptr = alg_name_array[i] + strlen(alg_name_array[i]);
    }
    increment_tuple(tuple);
  }

  /* old program used to ask for start date, but now just start at beginning*/
  read_upto_date(fileptr,1,1,1);

  if (argc >=6) sscanf(argv[5], "%d",&OUTPUTTIME);
  else if (interactive) {
    printf("Time to output (or <cr> for %d): ", OUTPUTTIME);
    gets(junk);
    if (*junk) sscanf(junk,"%d", &OUTPUTTIME);
  }

  throw_out_value = THROW_OUT_VALUE;

  if (argc >= 7) {
    if (*argv[6] == 'a') { aggressive = 1; num_left = num_algs; }
    if (*argv[6] == 'l') lowerbound = 1;
    if (isdigit(*argv[6])) {
      aggressive = 1; num_left = num_algs;
      sscanf(argv[6],"%f", &throw_out_value);
    }
  } else if (interactive) {
    printf("aggressive (throw out algs with weights < %g)? (<cr> for no) ",
	   throw_out_value);
    gets(junk);
    if (isdigit(*junk) || *junk == 'y') {
      aggressive = 1; 
      num_left = num_algs;
      if (isdigit(*junk)) {
	sscanf(junk, "%f",&throw_out_value);
	printf("OK, setting throw out value to %g\n",throw_out_value);
      }
    } else {
      printf("quick on our feet (don't cut weights below %g)? (<cr> for no) ",
	     throw_out_value);
      gets(junk);
      if (*junk == 'y') lowerbound = 1;
    }
  }
  if (aggressive) fprintf(summaryfp,"aggressive mode. threshold = %g\n",
			  throw_out_value);
  if (lowerbound) fprintf(summaryfp,"lowerbounding weights\n");

}

/* this reads in the NUM_INPUTS features.
   It first consumes the rest of the current line (the "next example:"
   header, per the original note), then reads NUM_INPUTS
   whitespace-separated tokens into inputs[].  Each token is truncated to
   INPUT_LEN-1 characters.
   Returns 1 on success, 0 on failure (i.e., end of file or a short read).
 */
int get_next_data(void)
{
  int i;
  char line[100];

  if (fgets(line, sizeof line, fileptr) == NULL)  /* get "next example:" */
    return 0;
  for (i = 0; i < NUM_INPUTS; ++i) {
    /* width 99 = sizeof line - 1: a plain %s could overflow line[] */
    if (fscanf(fileptr, "%99s", line) != 1) return 0;
    /* discard the rest of an over-long token so it isn't read as the next
       feature; a 0/EOF result here just means nothing was left over */
    fscanf(fileptr, "%*[^ \t\n\r\f\v]");
    /* bounded copy: keeps at most INPUT_LEN-1 chars without touching
       line[INPUT_LEN-1], which is outside line[] when INPUT_LEN > 100 */
    snprintf(inputs[i], INPUT_LEN, "%s", line);
  }
  return 1;
}
 
/* increments an array of integers indicating a subset TUPLE_SIZE elements
   of {0,...,NUM_INPUTS-1}.  Returns 1 if can, or 0 if we already had the last 
   one.  Eg. TS = 2, NI = 5.  Then, progression is:
                           01,02,03,04,12,13,14,23,24,34.
 */
increment_tuple(val_array)
     int val_array[TUPLE_SIZE];
{
  int i,j;
  for (j=TUPLE_SIZE-1; j > -1; --j)
    if (val_array[j] < NUM_INPUTS - (TUPLE_SIZE - j)) {
      ++val_array[j];
      for(i=j+1; i < TUPLE_SIZE; ++i)
	val_array[i] = val_array[i-1] + 1;
      return( 1 );
    }
  return( 0 );
}



/* get index of correct answer
   Update global variables.

 */
evaluate_result(void)
{
  int j;

  for(j=0; j < num_outputs; ++j) {
    if (strcmp(correct_ans, global_out_array[j]) == SAME) {
      correct_index = j;
      break;
    }
  }
  if (j == num_outputs) {  /* it is a NEW output */
    strcpy(global_out_array[num_outputs], correct_ans);
    correct_index = num_outputs;
    ++num_outputs;
  }
}




/**************hash table functions*************/
/* Hash of a situation (TUPLE_SIZE strings) into [0, HASH_SIZE).
   Bytes are read as unsigned char.  With plain (possibly signed) char, a
   byte >= 0x80 could drive h negative and index state_table out of bounds. */
int hash_situation(char *s[TUPLE_SIZE])
{
  int i, h = 0;
  const unsigned char *ptr;

  for (i = 0; i < TUPLE_SIZE; ++i) {
    for (ptr = (const unsigned char *) s[i]; *ptr != '\0'; ++ptr)
      h = (h * 4 + *ptr) % HASH_SIZE;
  }
  return h;
}

/* Returns the entry of table t whose situation equals s in all TUPLE_SIZE
   positions, or NULL if there is none. */
state_list *lookup(state_table t, char *s[TUPLE_SIZE])
{
  state_list *node = t[hash_situation(s)];

  while (node != NULL && sitcmp(node->situation, s) != TUPLE_SIZE)
    node = node->next;
  return node;
}
    
/* Adds a new entry for situation s to table t with a one-element history
   {value}, pushed onto the front of its hash bucket.  Callers (alg_update)
   only insert situations that lookup() did not find.  Out of memory is
   fatal: a message goes to stderr and the program exits, rather than
   writing through a NULL pointer.  Returns 0. */
int insert(state_table t, char *s[TUPLE_SIZE], int value)
{
  int index = hash_situation(s);
  state_list *ptr = calloc(1, sizeof *ptr);

  if (ptr == NULL) {
    fprintf(stderr, "insert: out of memory\n");
    exit(1);
  }
  sitcpy(ptr->situation, s);     /* copy data over */
  ptr->history[0] = value;
  ptr->historysize = 1;          /* 1 thing in "history" array */
  ptr->next = t[index];
  t[index] = ptr;
  return 0;
}
