
/**********************************************************************
 * $Id: bm.c,v 1.2 92/11/30 11:53:50 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		       Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute,  and sell this software
 * and its documentation for any purpose is hereby granted without fee,
 * provided  that the above copyright notice  appears in all copies and
 * that both the copyright notice and this permission notice  appear in
 * supporting documentation, and  that  the  name of The University  of
 * Toronto  not  be used  in advertising   or publicity pertaining   to
 * distribution  of   the software   without  specific, written   prior
 * permission.  The  University  of Toronto  makes   no representations
 * about the  suitability  of  this software  for  any purpose.   It is
 * provided "as is" without express or implied warranty.
 *
 * THE  UNIVERSITY OF  TORONTO DISCLAIMS ALL WARRANTIES  WITH REGARD TO
 * THIS SOFTWARE,  INCLUDING ALL  IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT  SHALL THE UNIVERSITY  OF TORONTO BE LIABLE
 * FOR ANY SPECIAL,  INDIRECT OR CONSEQUENTIAL  DAMAGES  OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF  USE, DATA OR PROFITS,  WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE  OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF OR  IN  CONNECTION   WITH  THE  USE OR  PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/


 /*********************************************************************
 *
 *  MFT/FEM/Boltzmann Modules written by     Evan W. Steeg
 *                                           Dept. of Computer Science
 *  August 1991                              Univ. of Toronto
 *
 **********************************************************************/


#include <stdio.h>
#include <math.h>

#include <xerion/useful.h>
#include <xerion/version.h>
#include <xerion/simulator.h>
#include "bm.h"
#include "bm-train.h"
#include "help.h"

/* Creation/destruction hooks registered with the simulator in main().
   Each init/deinit pair allocates/frees the object's extension record. */
static void	initNet     ARGS((Net	 net)) ;
static void	deinitNet   ARGS((Net	 net)) ;
static void	initGroup   ARGS((Group	group)) ;
static void	deinitGroup ARGS((Group group)) ;
static void	initUnit    ARGS((Unit	 unit)) ;
static void	deinitUnit  ARGS((Unit	 unit)) ;
static void	initLink    ARGS((Link	 link)) ;
static void	deinitLink  ARGS((Link	 link)) ;

/* Per-unit callbacks run by netActivityUpdate before annealing starts. */
static void     setOutput           ARGS((Unit  unit, void      *data)) ;
static void     setOldOutput        ARGS((Unit  unit, void      *data)) ;

/* Relaxation, correlation and gradient procedures installed on the
   net (initNet) and on each group (initGroup). */
static void	netActivityUpdate   ARGS((Net	net)) ;
static void	unitActivityUpdate  ARGS((Unit	unit)) ;
static void	unitPlusCorrUpdate  ARGS((Unit	unit)) ;
static void	unitMinusCorrUpdate ARGS((Unit	unit)) ;
static void     unitGradientUpdate  ARGS((Unit  unit)) ;
static void	updateActivity      ARGS((Unit	unit, void	*data)) ;
static void	updateNetError      ARGS((Unit	unit, void	*data)) ;

/* Stochastic state-change rules used by unitActivityUpdate. */
static int metrop ARGS((double delta_E, double T, double unitOutput)) ;
static int heatbath ARGS((double delta_E, double T)) ;

/* Trace object handed to IDoTrace() after each annealing sweep while the
   net is not flagged as running (see netActivityUpdate).  Not static, so
   it is visible to other translation units. */
struct TRACE    inRelaxation ;


/***********************************************************************
 *	Name:		main 
 *	Description:	the main function, used for the xerion simulator
 *	Parameters:	
 *		int	argc	- the number of input args
 *		char	**argv  - array of argument strings from command line
 *	Return Value:	
 *		int	main	- 0
 ***********************************************************************/
int main (argc, argv)
  int	argc ;
  char	**argv ;
{
  /* Author credit.  `authors` is not declared in this file; presumably a
     simulator global from the xerion headers -- confirm. */
  authors = "Evan Steeg" ;

  /* Seed the random number generator.
     NOTE(review): `longval` is not declared anywhere in this file, so
     MSEEDRAND is presumably a macro that supplies or ignores it -- confirm. */
  MSEEDRAND(longval);

  /* Register this module's create/destroy hooks for nets, groups, units
     and links.  This is done before IStandardInit() is called. */
  setCreateNetHook  (initNet) ;
  setDestroyNetHook (deinitNet) ;

  setCreateGroupHook(initGroup) ;
  setDestroyGroupHook (deinitGroup) ;

  setCreateUnitHook  (initUnit) ;
  setDestroyUnitHook (deinitUnit) ;

  setCreateLinkHook  (initLink) ;
  setDestroyLinkHook (deinitLink) ;

  /* Perform initialization of the simulator */
  IStandardInit(&argc, argv);

  /* Enter loop that reads commands and handles graphics */
  ICommandLoop(stdin, stdout, NULL);

  /* Reached once ICommandLoop() returns. */
  return 0 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initNet 
 *	Description:	sets the error calculation procedures for
 *			the net. As well, changes the activityUpdateProc,
 *			allocates memory for the extension and initializes
 *			some values in it
 *	Parameters:	
 *		Net	net - the net to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initNet(net)
  Net	net ;
{
  net->calculateErrorDerivProc = calculateErrorDeriv ;
  net->calculateErrorProc      = calculateError ;

  net->activityUpdateProc = netActivityUpdate ;

  net->extension = (NetExtensionRec *)calloc(1, sizeof(NetExtensionRec)) ;

  MtMax(net)   = T_MAX ;
  MtMin(net)   = T_MIN ;
  MtDecay(net) = T_DECAY ;
  MannealMethod(net) = HEATBATH ;
  Mrunning(net) = FALSE ;
}
/**********************************************************************/
/* Destroy hook for nets: releases the extension record allocated by
   initNet() and clears the pointer so it cannot be reused or freed twice. */
static void	deinitNet(net)
  Net	net ;
{
  if (net->extension != NULL) {
    free(net->extension) ;
    net->extension = NULL ;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		initGroup 
 *	Description:	sets the activity and weight updates for the units
 *			in a group. Also sets up the extension record.
 *	Parameters:	
 *		Group	group - the group to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initGroup(group)
  Group	group ;
{
  group->unitActivityUpdateProc = unitActivityUpdate ;
  group->unitGradientUpdateProc = unitGradientUpdate ;

  group->extension = (GroupExtensionRec *)calloc(1, sizeof(GroupExtensionRec));
  group->extension->unitPlusCorrUpdateProc  = unitPlusCorrUpdate ;
  group->extension->unitMinusCorrUpdateProc = unitMinusCorrUpdate ;
}
/**********************************************************************/
/* Destroy hook for groups: releases the extension record allocated by
   initGroup() and clears the pointer so it cannot be reused or freed twice. */
static void     deinitGroup(group)
  Group group ;
{
  if (group->extension != NULL) {
    free(group->extension) ;
    group->extension = NULL ;
  }
}
/**********************************************************************/

/***********************************************************************
 *	Name:		initLink 
 *	Description:	allocates the memory for the link extension record
 *	Parameters:	
 *		Link	link - the link to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initLink(link)
  Link	link ;
{
  link->extension = (LinkExtensionRec *)calloc(1, sizeof(LinkExtensionRec)) ;
}
/**********************************************************************/
/* Destroy hook for links: releases the extension record allocated by
   initLink() and clears the pointer so it cannot be reused or freed twice. */
static void	deinitLink(link)
  Link	link ;
{
  if (link->extension != NULL) {
    free(link->extension) ;
    link->extension = NULL ;
  }
}
/**********************************************************************/

/***********************************************************************
 *	Name:		initUnit 
 *	Description:	allocates the memory for the Unit extension record
 *	Parameters:	
 *		Unit	unit - the unit to act on
 *	Return Value:	NONE
 ***********************************************************************/
static void	initUnit(unit)
  Unit	unit ;
{
  unit->extension = (UnitExtensionRec *)calloc(1, sizeof(UnitExtensionRec)) ;
}
/**********************************************************************/
/* Destroy hook for units: releases the extension record allocated by
   initUnit() and clears the pointer so it cannot be reused or freed twice. */
static void	deinitUnit(unit)
  Unit	unit ;
{
  if (unit->extension != NULL) {
    free(unit->extension) ;
    unit->extension = NULL ;
  }
}
/**********************************************************************/

/***********************************************************************
 *	Name:		unitError
 *	Description:	calculates the error value of a unit assuming
 *                      that the global net.error measure is sum of 
 *                      squared errors.  
 *		        	
 *	Parameters:	
 *		const Unit	unit - the unit to calculate the error of
 *	Return Value:	
 *		Real	unitError - the error of the unit
 ***********************************************************************/
static Real	unitError(unit)
  const Unit	unit ;
{
  return (Real)square((unit->target - Moutput(unit))) / 2.0 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		dotProduct
 *	Description:	calculates the dot product of all incoming
 *			links for a unit and stores it in the totalinput
 *			field of the unit.
 *	Parameters:	
 *		const Unit	unit - the unit to calculate the dot 
 *					product for
 *	Return Value:	
 *		Real	dotProduct - the dot product
 *   
 ***********************************************************************/
static Real	dotProduct(unit)
  Unit	unit ;
{
  Link	*links = unit->incomingLink ;
  int	fanIn  = unit->numIncoming ;
  Real	sum    = 0.0 ;
  int	k      = 0 ;

  /* Accumulate weight * sender activity over every incoming link, in
     link order.  Under synchronous updating the sender's stored `old`
     value is used instead of its live output. */
  while (k < fanIn) {
    Link	link = links[k++] ;

    if (!MsynchronousUpdate(unit->net))
      sum += link->weight * link->preUnit->output ;
    else
      sum += link->weight * link->preUnit->extension->old ;
  }

  /* The result is cached on the unit as well as returned. */
  unit->totalInput = sum ;
  return sum ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		netActivityUpdate
 *	Description:	activates the network assuming an example has
 *			been input already. It does the annealing for
 *			the example.
 *	Parameters:	
 *		Net	net - the net to activate.
 *	Return Value:	
 ***********************************************************************/
static void	updateNetError(unit, data)
  Unit		unit ;
  void		*data ;
{
  unit->net->error += unitError(unit) ;
}
/**********************************************************************/
/* Per-unit callback: runs the unit's activity update through the
   MupdateUnitActivity macro.  `data` is not used. */
static void	updateActivity(unit, data)
  Unit		unit ;
  void		*data ;
{
  (void)data ;
  MupdateUnitActivity(unit) ;
}
/**********************************************************************/
/* Anneals the net for the example currently presented: randomizes every
   non-input, non-bias unit, then sweeps the units once per temperature
   step from tMax down towards tMin, and finally adds the error of the
   OUTPUT units to net->error. */
static void	netActivityUpdate(net)
  Net		net ;
{
  Real		tMin   = MtMin(net) ;
  Real		tMax   = MtMax(net) ;
  Real		tDecay = MtDecay(net) ;
  /* Widen before scaling so a large delayCount cannot overflow int. */
  long		delayLimit = (long)net->extension->delayCount * 1000L ;
  /* volatile: otherwise an optimizing compiler may delete the empty
     delay loop below, silently turning delayCount into a no-op. */
  volatile long	spin ;

  /* Start from a random state; in synchronous mode also snapshot every
     unit's output into its `old` field. */
  netForAllUnits(net, ~(INPUT | BIAS), setOutput, NULL) ;  
  if (MsynchronousUpdate(net)) netForAllUnits(net, ALL, setOldOutput, NULL) ;

  /* NOTE(review): the loop only terminates if the temperature eventually
     drops to tMin or below, which presumably means 0 < tDecay < 1 --
     confirm these settings are validated where they are changed. */
  for (Mtemp(net) = tMax ; Mtemp(net) > tMin ; Mtemp(net) *= tDecay) {
    if (MsynchronousUpdate(net)) 
      netForAllUnits(net, ~(INPUT | BIAS), updateActivity, NULL) ;
    else 
      netForAllUnitsRandom(net, ~(INPUT | BIAS), updateActivity, NULL) ;

    /* When the net is not flagged as running, fire the relaxation trace
       and busy-wait for a period proportional to delayCount. */
    if (Mrunning(net) != TRUE) {
      IDoTrace(&inRelaxation) ;
      for (spin = 0 ; spin < delayLimit ; spin++)
	;
    }
  }

  netForAllUnits(net, OUTPUT, updateNetError, NULL) ;
}
/**********************************************************************/
/*
 * Metropolis and Heatbath stochastic state-change methods, used
 * in conjunction with the unit activity update code.  Basically,
 * 
 *  Metrop:    if changing current state reduces energy, do it
 *             else do it anyway, with prob exp(-delta_E/T)
 * 
 *  Heatbath:  set unit activ = 1 with prob  1/(1+ exp(delta_E/T))
 *               where delta_E is  E[unit on] - E[unit off].  
 *
 **********************************************************************/
static int metrop(delta_E,T,unitOutput)
  double delta_E, T, unitOutput;
{
  double gain ;

  /* delta_E is used as given when the unit currently rounds to a nonzero
     state, and negated when it rounds to zero. */
  if ((int)rint(unitOutput) != 0)
    gain = delta_E ;
  else
    gain = 0.0 - delta_E ;

  /* A positive value is always accepted, without drawing a random number. */
  if (gain > 0.0)
    return 1 ;

  /* Otherwise accept with probability exp(gain / T). */
  return (MRANDOM() <  exp(gain / T)) ;
}
/**********************************************************************/
static int heatbath(delta_E,T)
 double delta_E, T;
{
  /* Returns 1 with probability 1 / (1 + exp(delta_E / T)), else 0. */
  if (MRANDOM() < 1.0/(1.0 + exp((delta_E / T))))
    return 1 ;

  return 0 ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		unitActivityUpdate
 *	Description:	Update the activation a_i ("spin s_i") of unit i.
 *                      Calc weighted sum of inputs, then use metrop or
 *                      heatbath as described above.
 *	
 *	Parameters:	
 *		Unit	unit - the unit to activate.
 *	Return Value:	NONE
 ***********************************************************************/
static void	unitActivityUpdate(unit)
  Unit		unit ;
{
  double	energyGap ;

  /* The energy gap passed to the update rule is the negated net input. */
  unit->totalInput = dotProduct(unit) ;
  energyGap = (double)(0 - unit->totalInput); 

  if (USE_METROPOLIS(unit->net)) {
    /* Metropolis: an accepted move flips the current (rounded) state. */
    if (metrop(energyGap,
	       (double)Mtemp(unit->net), (double)Moutput(unit))) {
      if ((int)rint(Moutput(unit)))
	Moutput(unit) = 0.0 ;
      else
	Moutput(unit) = 1.0 ;
    }
  } else if (USE_HEATBATH(unit->net)) {
    /* Heatbath: the draw decides the new state outright. */
    if (heatbath(energyGap, (double)Mtemp(unit->net)))
      Moutput(unit) = 1.0 ;
    else
      Moutput(unit) = 0.0 ;
  }

  /* NOTE(review): `old` is refreshed as soon as this unit changes, so
     units visited later in the same sweep already see the new value
     through dotProduct().  Confirm this is the intended meaning of
     "synchronous" updating. */
  if (MsynchronousUpdate(unit->net)) 
    unit->extension->old = Moutput(unit) ;
}
/**********************************************************************/


/***********************************************************************
 *      Name:           unitGradientUpdate
 *      Description:    updates the gradients of all the incoming links
 *                      to a unit.  
 *                      deriv = (s_i*s_j)minus - (s_i*s_j)plus, i.e. the
 *                      negative of (s_i*s_j)plus - (s_i*s_j)minus,
 *                      because we want to follow the -gradient.
 *      Parameters:
 *              Unit    unit - the unit to update the gradients of.
 *      Return Value:   NONE
 ***********************************************************************/
static void     unitGradientUpdate(unit)
  Unit  unit ;
{
  int   numIncoming = unit->numIncoming ;
  Link  *incoming   = unit->incomingLink ;
  int   idx ;
  Real  divisor ;

  /* Correlations s_i * s_j have been gathered through sampling over 
     annealing runs, sampling sweeps at equilibrium, and examples.  So
     divide totals to get averages.  The divisor depends only on the net,
     so it is computed once rather than once per link.                  */
  divisor =  (Real)MnumSamplingSweeps(unit->net) 
    * (Real)MnumRelaxations(unit->net)
    * (Real)MnumExamples(unit->net);

  for (idx = 0 ; idx < numIncoming ; ++idx) {
    Link        link  = incoming[idx] ;

    if (divisor != 0.0) {
      /* deriv = average minus-phase correlation - average plus-phase one */
      link->deriv = (MminusCorr(link)/divisor) - (MplusCorr(link)/divisor) ;
    } else {
      /* One of the counts is zero: dividing would store NaN/Inf in the
	 derivative, so report a zero gradient instead. */
      link->deriv = 0.0 ;
    }
  }
}
/**********************************************************************/

/***********************************************************************
 *	Name:		unitPlusCorrUpdate
 *	Description:	updates the Plus Phase Correlation of all the 
 *                      incoming links to a unit.
 *	Parameters:	
 *		Unit	unit - the unit to update the weights of.
 *	Return Value:	NONE
 ***********************************************************************/
static void	unitPlusCorrUpdate(unit)
  Unit	unit ;
{
  int	remaining = unit->numIncoming ;
  Link	*next     = unit->incomingLink ;

  /* Move each incoming link's pending correlation count into its
     plus-phase total, then reset the count. */
  while (remaining-- > 0) {
    Link	link = *next++ ;

    MplusCorr(link) += McorrCount(link) ;
    McorrCount(link) = 0.0;
  }
}
/**********************************************************************/


/***********************************************************************
 *	Name:		unitMinusCorrUpdate
 *	Description:	updates the Minus Phase Correlation of all the 
 *                      incoming links to a unit.
 *	Parameters:	
 *		Unit	unit - the unit to update the weights of.
 *	Return Value:	NONE
 ***********************************************************************/
static void	unitMinusCorrUpdate(unit)
  Unit	unit ;
{
  Link	*links = unit->incomingLink ;
  int	fanIn  = unit->numIncoming ;
  int	k ;

  /* Move each incoming link's pending correlation count into its
     minus-phase total, then reset the count. */
  for (k = 0 ; k != fanIn && k < fanIn ; k++) {
    Link	link = links[k] ;

    MminusCorr(link) += McorrCount(link) ; 
    McorrCount(link) = 0.0;
  }
}
/**********************************************************************/

/**********************************************************************
 *   Set outputs to random vals before annealing.                     * 
 **********************************************************************/
static void     setOutput(unit, data)
  Unit          unit ;
  void          *data ;
{
  /* Set to random val of 0 or 1  */

  unit->output = (Real)(0.5 < MRANDOM());
}
/**********************************************************************
 *  Store outputs from sweep t for use in sweep t+1                   * 
 **********************************************************************/
static void     setOldOutput(unit, data)
  Unit          unit ;
  void          *data ;
{
  (void)data ;

  /* Snapshot the live output into the extension record's `old` field,
     which dotProduct() reads under synchronous updating. */
  unit->extension->old = unit->output ;
}

/**********************************************************************/
/* Returns x * x, computed in double and converted to Real. */
Real		square(x)
  double	x ;
{
  return (Real)(x * x) ;
}
/**********************************************************************/
