
/**********************************************************************
 * $Id: bm-train.c,v 1.3 92/11/30 11:53:48 drew Exp $
 **********************************************************************/

/**********************************************************************
 *   Copyright 1990,1991,1992,1993 by The University of Toronto,
 *		       Toronto, Ontario, Canada.
 * 
 *			 All Rights Reserved
 * 
 * Permission to use, copy, modify, distribute,  and sell this software
 * and its documentation for any purpose is hereby granted without fee,
 * provided  that the above copyright notice  appears in all copies and
 * that both the copyright notice and this permission notice  appear in
 * supporting documentation, and  that  the  name of The University  of
 * Toronto  not  be used  in advertising   or publicity pertaining   to
 * distribution  of   the software   without  specific, written   prior
 * permission.  The  University  of Toronto  makes   no representations
 * about the  suitability  of  this software  for  any purpose.   It is
 * provided "as is" without express or implied warranty.
 *
 * THE  UNIVERSITY OF  TORONTO DISCLAIMS ALL WARRANTIES  WITH REGARD TO
 * THIS SOFTWARE,  INCLUDING ALL  IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS, IN NO EVENT  SHALL THE UNIVERSITY  OF TORONTO BE LIABLE
 * FOR ANY SPECIAL,  INDIRECT OR CONSEQUENTIAL  DAMAGES  OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF  USE, DATA OR PROFITS,  WHETHER IN
 * AN ACTION OF CONTRACT, NEGLIGENCE  OR OTHER TORTIOUS ACTION, ARISING
 * OUT  OF OR  IN  CONNECTION   WITH  THE  USE OR  PERFORMANCE  OF THIS
 * SOFTWARE.
 *
 **********************************************************************/

 /*********************************************************************
 *
 *  MFT/FEM/Boltzmann Modules written by     Evan W. Steeg
 *                                           Dept. of Computer Science
 *  August 1991                              Univ. of Toronto
 *
 **********************************************************************/

#include <stdio.h>
#include <math.h>

#include <xerion/simulator.h>
#include "bm.h"
#include "bm-train.h"

/* Modes for setIncomingProducts(), passed through its void *data
 * argument as (void *)MODE. */
enum {
  CLEAR1   = 0,	/* McorrCount(link) = 0 (done once per example)           */
  CLEAR2   = 1,	/* MplusCorr(link) = MminusCorr(link) = 0 (once per batch) */
  POSITIVE = 2,	/* McorrCount(link) += pre->output * post->output         */
  NEGATIVE = 3	/* same accumulation, used during the negative phase      */
} ;

/**********************************************************************/
/* Per-unit callbacks handed to netForAllUnits()/netForAllUnitsRandom().
 * All share the (Unit, void *data) callback signature. */

/* ... unit state initialisation */
static void	setOutput		ARGS((Unit unit, void *data)) ;
static void	clampOutput		ARGS((Unit unit, void *data)) ;
static void	setOldOutput		ARGS((Unit unit, void *data)) ;

/* ... relaxation */
static void	updateActivity		ARGS((Unit unit, void *data)) ;

/* ... link statistics and derivatives */
static void	setIncomingProducts	ARGS((Unit unit, void *data)) ;
static void	updatePlusCorr		ARGS((Unit unit, void *data)) ;
static void	updateMinusCorr		ARGS((Unit unit, void *data)) ;
static void	zeroLinks		ARGS((Unit unit, void *data)) ;
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateErrorDeriv
 *	Description:	procedure for calculating error and associated
 *			derivatives for a Boltzmann Machine.
 *			It performs a positive annealed phase, a
 *			negative annealed phase, and then updates
 *			the gradients for each example.
 *	Parameters:	
 *		Net		net - the network to use
 *		ExampleSet	exampleSet - the example set to use
 *	Return Value:	
 *		NONE
 ***********************************************************************/
void		calculateErrorDeriv(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  Real		tMin  = MtMin(net) ;
  Real		tMax  = MtMax(net) ;
  Real		tDecay = MtDecay(net) ;
  int		numExamples ;
  int           numSamplingSweeps; 
  int           numRelaxations; 
  int           sweep;
  int           relaxPhase;

  Mrunning(net) = TRUE ;
  net->error = 0.0 ;
  netForAllUnits(net, ALL, zeroLinks, NULL) ;


  numSamplingSweeps = MnumSamplingSweeps(net) ;
  numRelaxations = MnumRelaxations(net) ;
  MrelaxSweepCount(net) = 0 ;
  MrelaxSweepCountAve(net) = 0.0 ;

  /* Run through all the training cases. */
  netForAllUnits(net, ALL, setIncomingProducts, (void *)CLEAR2) ;
  for (numExamples = 0 ; numExamples < net->batchSize ; ++numExamples) {
    MgetNext(exampleSet) ;
    /* Go through numRelaxations relaxations (anneals), and at the
       end of each do numSamplingSweeps s_i*s_j calculations.               
       Do either asynchronous (with random net traversal) or
       synchronous (with fast inorder traversal) updating. */

    netForAllUnits(net, ALL, setIncomingProducts, (void *)CLEAR1) ;
    for (relaxPhase =0; relaxPhase < numRelaxations; relaxPhase++) {
      /* Positive phase */
      netForAllUnits(net, ~(INPUT | OUTPUT | BIAS), setOutput, NULL) ;
      netForAllUnits(net, OUTPUT, clampOutput, NULL) ;
      if (MsynchronousUpdate(net)) 
	netForAllUnits(net, ALL, setOldOutput, NULL) ;

      /* If there noAnnealInPosPhase set, e.g., if no links between any
         hidden units so no need to anneal in clamped phase, then don't
         anneal, but just do one relaxation.                            */

      if (MnoAnnealInPosPhase(net)) {
       Mtemp(net) = tMin;
       if (MsynchronousUpdate(net))
          netForAllUnits(net, ~(INPUT | OUTPUT | BIAS), updateActivity, NULL) ;
       else
          netForAllUnitsRandom(net, ~(INPUT | OUTPUT | BIAS),
                               updateActivity, NULL) ;
       MrelaxSweepCount(net)++ ;
      }


      else  /* Go ahead and do pos phase annealing */
      for (Mtemp(net) = tMax ; Mtemp(net) > tMin ; Mtemp(net) *= tDecay) {
	if (MsynchronousUpdate(net))
	  netForAllUnits(net, ~(INPUT | OUTPUT | BIAS), updateActivity, NULL) ;
	else
	  netForAllUnitsRandom(net, ~(INPUT | OUTPUT | BIAS), 
			       updateActivity, NULL) ;
	MrelaxSweepCount(net)++ ;
      } 
  
      /* Positive phase stats sampling */
      Mtemp(net)=tMin;
      for (sweep=0; sweep < numSamplingSweeps; sweep++) {
	if (MsynchronousUpdate(net))
	  netForAllUnits(net, ~(INPUT | OUTPUT | BIAS), 
			 updateActivity, NULL);
	else 
	  netForAllUnitsRandom(net, ~(INPUT | OUTPUT | BIAS),
			       updateActivity, NULL);
	MrelaxSweepCount(net)++ ;
	netForAllUnits(net, ALL, setIncomingProducts, (void *)POSITIVE) ;
      }
    }

    netForAllUnits(net, ALL, updatePlusCorr, NULL) ;

    for (relaxPhase =0; relaxPhase < numRelaxations; relaxPhase++) {
      /* Negative phase */
      netForAllUnits(net, ~(INPUT | BIAS), setOutput, NULL) ;
      if (MsynchronousUpdate(net))
	netForAllUnits(net, ALL, setOldOutput, NULL) ;
      for (Mtemp(net) = tMax ; Mtemp(net) > tMin ; Mtemp(net) *= tDecay) {
	if (MsynchronousUpdate(net))
	  netForAllUnits(net, ~(INPUT | BIAS), updateActivity, NULL) ;
	else
	  netForAllUnitsRandom(net, ~(INPUT | BIAS), updateActivity, NULL) ;
	MrelaxSweepCount(net)++ ;
      }


      /* Negative phase stats sampling */
      Mtemp(net)=tMin;
      for (sweep=0; sweep<numSamplingSweeps; sweep++) {
	if (MsynchronousUpdate(net))
	  netForAllUnits(net, ~(INPUT | BIAS), updateActivity, NULL);
	else 
	  netForAllUnitsRandom(net, ~(INPUT | BIAS), updateActivity, NULL);
	MrelaxSweepCount(net)++ ;
	netForAllUnits(net, ALL, setIncomingProducts, (void *)NEGATIVE) ;
      }
    }

    netForAllUnits(net, ALL, updateMinusCorr, NULL) ;

    /* update the error for the net */
    MupdateNetActivities(net) ;
  }

  if (numExamples <= 0)
    IErrorAbort("calculateErrorDeriv: no examples processed") ;

  MnumExamples(net) = numExamples;
  MrelaxSweepCountAve(net) = (Real)MrelaxSweepCount(net)/(Real)numExamples ;

  /* gradient update */
  MupdateNetGradients(net) ;

  /* update the cost after everything else is done */
  MevaluateCostAndDerivs(net) ;

  Mrunning(net) = FALSE ;
}
/**********************************************************************/


/***********************************************************************
 *	Name:		calculateError
 *	Description:	procedure for calculating the error on an example
 *			set for a Boltzmann Machine.  It fetches
 *			net->batchSize examples and calls the net's
 *			updateActivities procedure (MupdateNetActivities)
 *			for each.  It then evaluates the cost once.
 *			No gradients are updated.
 *	Parameters:	
 *		Net		net - the network to test
 *		ExampleSet	exampleSet - the example set to test on
 *	Return Value:	
 *		NONE.  Reports an error via IErrorAbort if the batch
 *		size is not positive.
 ***********************************************************************/
void		calculateError(net, exampleSet)
  Net		net ;
  ExampleSet	exampleSet ;
{
  int		numExamples ;

  /* Validate before touching the net, so that if IErrorAbort transfers
     control elsewhere, Mrunning(net) is not left TRUE.  This is the same
     condition as "the loop below processed no examples". */
  if (net->batchSize <= 0)
    IErrorAbort("calculateError: no examples processed") ;

  Mrunning(net) = TRUE ;
  net->error = 0.0 ;
  for (numExamples = 0 ; numExamples < net->batchSize ; ++numExamples) {
    MgetNext(exampleSet) ;
    MupdateNetActivities(net) ;
  }

  /* update the cost after everything else is done */
  MevaluateCost(net) ;

  Mrunning(net) = FALSE ;
}
/**********************************************************************/


/**********************************************************************/
/* Callback: clamps the unit's output to its target value.
 * The data argument is unused. */
static void	clampOutput(unit, data)
  Unit		unit ;
  void		*data ;
{
  (void)data ;
  unit->output = unit->target ;
}
/**********************************************************************/
static void	setOutput(unit, data)
  Unit		unit ;
  void		*data ;
{
  /* Set to random val of 0 or 1  */
  unit->output = (Real)(0.5 < MRANDOM());
}
/**********************************************************************/
/* Callback: copies the unit's current output into its extension's
 * `old` field.  Invoked before relaxing when the net updates
 * synchronously.  The data argument is unused. */
static void     setOldOutput(unit, data)
  Unit          unit ;
  void          *data ;
{
  (void)data ;
  unit->extension->old = unit->output ;
}
/**********************************************************************/
/* Callback adapter: applies MupdateUnitActivity to one unit.
 * The data argument is unused. */
static void	updateActivity(unit, data)
  Unit		unit ;
  void		*data ;
{
  (void)data ;
  MupdateUnitActivity(unit) ;
}
/**********************************************************************/
/* Callback adapter: applies MupdatePlusCorr to one unit (run after the
 * positive phase).  The data argument is unused. */
static void     updatePlusCorr(unit, data)
  Unit          unit ;
  void          *data ;
{
  (void)data ;
  MupdatePlusCorr(unit) ;
}
/**********************************************************************/
/* Callback adapter: applies MupdateMinusCorr to one unit (run after the
 * negative phase).  The data argument is unused. */
static void     updateMinusCorr(unit, data)
  Unit          unit ;
  void          *data ;
{
  (void)data ;
  MupdateMinusCorr(unit) ;
}
/**********************************************************************/
/**********************************************************************
 * setIncomingProducts: callback that updates the statistics stored on
 * each of the unit's incoming links.  The mode is passed as
 * (void *)MODE in data:
 *   POSITIVE, NEGATIVE - McorrCount += pre->output * post->output
 *   CLEAR1             - McorrCount = 0
 *   CLEAR2             - MplusCorr = MminusCorr = 0
 * Any other mode leaves the links untouched.
 **********************************************************************/
static void	setIncomingProducts(unit, data)
  Unit		unit ;
  void		*data ;
{
  /* Go through long: converting a pointer directly to the narrower int
     is implementation-defined and warned about on LP64.  long is
     pointer-sized on ILP32 and LP64 platforms. */
  int		mode        = (int)(long)data ;
  int		numIncoming = unit->numIncoming ;
  Link		*incoming   = unit->incomingLink ;
  int		idx ;

  for (idx = 0 ; idx < numIncoming ; ++idx) {
    Link	link = incoming[idx] ;

    switch (mode) {
    case POSITIVE:
    case NEGATIVE:
      /* Both phases accumulate into the same per-link counter. */
      McorrCount(link) += link->preUnit->output*link->postUnit->output ;
      break ;
    case CLEAR1:
      McorrCount(link) = 0.0 ;
      break ;
    case CLEAR2:
      MminusCorr(link) = MplusCorr(link) = 0.0 ;
      break ;
    default:
      /* unknown mode: nothing to do */
      break ;
    }
  }
}
/**********************************************************************/


/*********************************************************************
 *      Name:           zeroLinks
 *      Description:    callback that resets the deriv field of every
 *                      link coming into the unit to 0.0
 *      Parameters:
 *        Unit          unit - the unit whose incoming links are reset
 *        void          *data - UNUSED
 *      Return Value:
 *        NONE
 *********************************************************************/
static void     zeroLinks(unit, data)
  Unit          unit ;
  void          *data ;
{
  Link	*links = unit->incomingLink ;
  int	count  = unit->numIncoming ;
  int	i      = 0 ;

  (void)data ;
  while (i < count)
    links[i++]->deriv = 0.0 ;
}
/********************************************************************/
