//
// neuron.hh
//
// Made by Guillaume Stordeur
// Login   <kami@GrayArea.Masaq>
//
// Started on  Thu Aug  1 04:59:35 2002 Guillaume Stordeur
// Last update Mon May  5 21:45:49 2003 Guillaume Stordeur
//

#ifndef   	NEURON_HH_
# define   	NEURON_HH_

# include <cassert>
# include <vector>
# include "utils.hh"
# include "Matrix.hh"

// ACT_STRING(T, S): assign to S the printable name of the activation
// function identifier T ("unknown" for any value not listed below).
// The macro expands to a bare switch statement, so it must be used where
// a statement is allowed and where the ACT_* enumerators are visible.
#define ACT_STRING(T, S)                                        \
  switch (T)                                                    \
    {                                                           \
    case ACT_SGN:            S = "sgn";            break;       \
    case ACT_LINEAR:         S = "linear";         break;       \
    case ACT_SIGMOID:        S = "sigmoid";        break;       \
    case ACT_SIGMOID_APPROX: S = "sigmoid_approx"; break;       \
    case ACT_TANH:           S = "tanh";           break;       \
    case ACT_TANH_APPROX:    S = "tanh_approx";    break;       \
    case ACT_GAUSS:          S = "gauss";          break;       \
    default:                 S = "unknown";        break;       \
    }

// Bit flags understood by Neuron::clearset(); combine them with '|'.
// Every value is parenthesized so that the macro expands to a single
// operand: without the parentheses, an expression such as `OUTPUT * 2`
// or `8 - DWSUM` would be regrouped by operator precedence.
// Bit 0 (1 << 0) is not assigned.
#define DWSUM		(1 << 1)	// reset the delta-weight sum
#define DELTAE		(1 << 2)	// zero every deltae entry
#define MOMENTUM	(1 << 3)	// zero every momentum entry
#define OLD_DELTAE	(1 << 4)	// zero every old-deltae entry
#define DECAY_DELTAE	(1 << 5)	// set deltae[i] = decay * weight[i]
#define TRI		(1 << 6)	// set every tri entry to the given value
#define OUTPUT		(1 << 7)	// reset the cached output

namespace NeuralNet
{

// Identifiers of the available activation functions.
// ACT_STRING maps each of them to a printable name, and newNeuron()
// takes one of them to select the kind of neuron to allocate.
enum e_ActivationFunction
{
  ACT_SGN            = 0,
  ACT_LINEAR         = 1,
  ACT_SIGMOID        = 2,
  ACT_SIGMOID_APPROX = 3,
  ACT_TANH           = 4,
  ACT_TANH_APPROX    = 5,
  ACT_GAUSS          = 6
};
typedef e_ActivationFunction ActivationFunctionType;




//--------------------------------------
// Neuron: abstract base class of every neuron declared in this header.
// It stores the incoming connections (source neurons and their weights),
// the per-connection training buffers and the cached activation state.
// Concrete neurons must implement getFPrime(), refreshOutput() and
// display().
//
class Neuron
{
public:
  // Build an unconnected neuron: numeric state at 0, every flag false.
  Neuron()
    : _output(0),
      _s(0),
      _dwsum(0),
      _fixed(false),
      _timeLagged(false),
      _recurrent(false)
  {}

  virtual ~Neuron() {}

  // Add a connection from src to this neuron's input, with the given
  // weight.
  void addConnection(Neuron *src, float weight);

  // Same as above, with a random weight.
  void addConnection(Neuron *src);

  // Update the weights directly (stochastic mode).
  void updateBackpropStochastic(float lRate, float moment, float delta);

  // Update deltae (batch mode).
  virtual void updateBatch(float delta);

  // Final weight change functions for batch mode, one per algorithm:
  //   rprop
  virtual void updateWeights(float nPlus, float nMinus,
                             float deltaMin, float deltaMax,
                             bool errUp);
  //   backprop
  virtual void updateWeights(float lRate, float moment);
  //   quickprop
  virtual void updateWeights(float lRate, float moment, float mu);

  // Return the index of n in the input-neuron vector,
  // or -1 when n is not one of the inputs.
  int isInputNeuron(Neuron *n);

  //-=-=-=-=-=-=-=-=-=-=-=-=-=-=
  // INTERFACE

  // Overwrite weight i; i must be a valid index (checked with assert).
  void setWeight(unsigned int i, float w)
  {
    assert(i < _weights.size());
    _weights[i] = w;
  }

  void setWeights(Matrix<double>& m);

  // Set every tri entry to t.
  void setTri(float t)
  {
    for (std::vector<float>::iterator it = _tri.begin();
         it != _tri.end(); ++it)
      *it = t;
  }

  // Set every old-deltae entry to t.
  void setOldDeltae(float t)
  {
    for (std::vector<float>::iterator it = _oldDeltae.begin();
         it != _oldDeltae.end(); ++it)
      *it = t;
  }

  // Zero every old-deltae entry.
  void clearOldDeltae() { setOldDeltae(0); }

  void setFixed(bool b) { _fixed = b; }
  void setTimeLagged(bool b) { _timeLagged = b; }

  // Accumulate a into the delta-weight sum.
  void incDwsum(float a) { _dwsum += a; }

  void clearDwsum() { _dwsum = 0; }

  // Zero every deltae entry.
  void clearDeltae()
  {
    for (std::vector<float>::iterator it = _deltae.begin();
         it != _deltae.end(); ++it)
      *it = 0;
  }

  // Zero every momentum entry.
  void clearMomentum()
  {
    for (std::vector<float>::iterator it = _momentum.begin();
         it != _momentum.end(); ++it)
      *it = 0;
  }

  // Overwrite deltae[k] with decay * weight[k] for each input connection.
  // The bound is the number of input neurons; the deltae and weight
  // vectors are indexed without a size check.
  void decayDeltae(float decay)
  {
    const std::vector<Neuron*>::size_type nbInputs = _inputNeurons.size();
    for (std::vector<Neuron*>::size_type k = 0; k < nbInputs; ++k)
      _deltae[k] = decay * _weights[k];
  }

  bool getRecurrent() const { return _recurrent; }

  void setRecurrent(bool i) { _recurrent = i; }

  // All-in-one clear/set function driven by the bit flags DWSUM, OUTPUT,
  // TRI, OLD_DELTAE, DELTAE, DECAY_DELTAE and MOMENTUM.
  // decay is only read for DECAY_DELTAE and tri only for TRI.
  // The per-connection vectors are indexed up to the number of input
  // neurons without any size check.
  virtual void clearset(int flags, float decay, float tri)
  {
    if (flags & (DWSUM))
      _dwsum = 0;
    if (flags & (OUTPUT))
      _output = 0;

    // Exactly one scalar flag and nothing else: skip the connection loop.
    if (flags == (DWSUM) || flags == (OUTPUT))
      return;

    const bool wantTri       = (flags & (TRI)) != 0;
    const bool wantOldDeltae = (flags & (OLD_DELTAE)) != 0;
    const bool wantDeltae    = (flags & (DELTAE)) != 0;
    const bool wantDecay     = (flags & (DECAY_DELTAE)) != 0;
    const bool wantMomentum  = (flags & (MOMENTUM)) != 0;

    const std::vector<Neuron*>::size_type nbInputs = _inputNeurons.size();
    for (std::vector<Neuron*>::size_type k = 0; k < nbInputs; ++k)
      {
        if (wantTri)
          _tri[k] = tri;
        if (wantOldDeltae)
          _oldDeltae[k] = 0;
        if (wantDeltae)
          _deltae[k] = 0;
        // Applied after DELTAE on purpose: when both flags are set the
        // decayed value is the one that remains.
        if (wantDecay)
          _deltae[k] = decay * _weights[k];
        if (wantMomentum)
          _momentum[k] = 0;
      }
  }

  // Number of input connections.
  unsigned int getNBInputs() const
  {
    return static_cast<unsigned int>(_inputNeurons.size());
  }

  // Last computed output (not recomputed here).
  virtual float getOutput() const { return _output; }

  bool getFixed() const { return _fixed; }
  bool getTimeLagged() const { return _timeLagged; }

  float getDwsum() const { return _dwsum; }

  // Weight i; i must be a valid index (checked with assert).
  float getWeight(unsigned int i) const
  {
    assert(i < _weights.size());
    return _weights[i];
  }

  // Copy of the weights, widened to double, wrapped in a Matrix.
  Matrix<double> getWeights() const
  {
    std::vector<double> widened(_weights.begin(), _weights.end());
    Matrix<double> result(widened);
    return result;
  }

  // Input neuron i; i must be a valid index (checked with assert).
  Neuron *getInputNeuron(unsigned int i) const
  {
    assert(i < _inputNeurons.size());
    return _inputNeurons[i];
  }

  //-=-=-=-=-=-=-=-=-=-=-=-=-=-=
  // NEURON SPECIFIC FUNCTIONS

  // Get the derivative F'(s).
  virtual float getFPrime() const = 0;

  // Process inputs and calculate activation (output).
  virtual float refreshOutput() = 0;

  // Display info to standard output.
  virtual void display() const = 0;

protected:
  // Calc weighted sum.
  void _calcWeightedSum();

  // Neurons that have their output connected to this neuron.
  std::vector<Neuron*> _inputNeurons;

  // Weight vector, 1st elem being the threshold.
  std::vector<float> _weights;

  // Learning momentum vector.
  std::vector<float> _momentum;

  // Per-connection value set by setTri() / clearset(TRI).
  // NOTE(review): presumably the rprop step size - confirm in the
  // implementation file.
  std::vector<float> _tri;

  // Per-connection deltae values and their previous copy.
  // NOTE(review): presumably the error gradient accumulated in batch
  // mode - confirm in the implementation file.
  std::vector<float> _deltae;
  std::vector<float> _oldDeltae;

  // Last output, updated when the output function is called.
  float _output;

  // Weighted sum.
  float _s;

  // Delta-weight sum, used in backpropagation.
  float _dwsum;

  // Fixed-weights and timeLagged flags.
  bool _fixed;
  bool _timeLagged;

  // Neuron has its input fed from another neuron in the network.
  bool _recurrent;
};

//--------------------------------------
// Threshold Neuron
// * output = 1
// This neuron is connected to a
// threshold connection of another neuron,
// and does not take any inputs.
// It always returns 1.
//
class ThresholdNeuron : public Neuron
{
public:
  //	Constructor
  ThresholdNeuron() { _output = 1; }

  //    Process inputs and calculate activation (output)
  float refreshOutput(void) { return _output; }

  //	Get the derivative F'(s)
  float	getFPrime(void) const { return 1; }

  //	Display info to standard output
  void	display(void) const {}
};


//--------------------------------------
// Input Neuron
// * output = input
// this neuron has 1 input connected to its output
//
class InputNeuron : public Neuron
{
public:
  //	Constructor
  InputNeuron(float input = 0) { _output = input; }


  //	Set up the input
  void	setInput(float input) { _output = input; }

  //    Process inputs and calculate activation (output)
  float refreshOutput(void)
  {
    if (_inputNeurons.size() != 0)
      _output = _inputNeurons[0]->getOutput();
    return _output;
  }

  //	Get the derivative F'(s)
  float	getFPrime(void) const { return 1; }

  //	Display info to standard output
  void	display(void) const {}
};


//--------------------------------------
// Sigmoid activation Neuron
// * output = sigmoid(activation)
//
class SigmoidNeuron : public Neuron
{
public:
  // Constructor: register the threshold neuron n as an input connection
  // weighted by threshold.
  SigmoidNeuron(ThresholdNeuron *n, float threshold = 0)
    : Neuron()
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate activation (output).
  virtual float refreshOutput();

  // Derivative F'(s), expressed from the last output y: y * (1 - y).
  virtual float getFPrime() const
  {
    return _output * (1 - _output);
  }

  // Display info to standard output.
  virtual void display() const;

protected:
  // Sigmoid non-linear activation function.
  inline float _sigmoid(float s);
};

//--------------------------------------
// Approximated Sigmoid activation Neuron
// * output = sigmoid_approx(activation)
//
class ApproxSigmoidNeuron : public Neuron
{
public:
  // Constructor: register the threshold neuron n as an input connection
  // weighted by threshold.
  ApproxSigmoidNeuron(ThresholdNeuron *n, float threshold = 0)
    : Neuron()
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate activation (output).
  virtual float refreshOutput();

  // Derivative F'(s) from the weighted sum _s:
  // 0 once |_s| reaches 1, otherwise 1 - |_s|.
  virtual float getFPrime() const
  {
    if (fabs(_s) >= 1)
      return 0;
    return 1 - fabs(_s);
  }

  // Display info to standard output.
  virtual void display() const;

protected:
  // Approximated Sigmoid non-linear activation function.
  inline float _sigmoid_approx(float s);
};

//--------------------------------------
// Gauss activation Neuron
// * output = gauss(activation)
//
class GaussNeuron : public Neuron
{
public:
  // Constructor: register the threshold neuron n as an input connection
  // weighted by threshold.
  GaussNeuron(ThresholdNeuron *n, float threshold = 0)
    : Neuron()
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate activation (output).
  virtual float refreshOutput();

  // Derivative F'(s), from the weighted sum and the last output:
  // -2 * _s * _output.
  virtual float getFPrime() const
  {
    return -2 * _s * _output;
  }

  // Display info to standard output.
  virtual void display() const;

protected:
  // Gauss non-linear activation function.
  inline float _gauss(float s);
};

//--------------------------------------
// Tanh activation Neuron
// * output = tanh(activation)
//
class TanhNeuron : public Neuron
{
public:
  // Constructor: register the threshold neuron n as an input connection
  // weighted by threshold.
  TanhNeuron(ThresholdNeuron *n, float threshold = 0)
    : Neuron()
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate activation (output).
  virtual float refreshOutput();

  // Derivative F'(s), expressed from the last output y: 1 - y * y.
  virtual float getFPrime() const
  {
    return 1 - _output * _output;
  }

  // Display info to standard output.
  virtual void display() const;

protected:
  // Tanh non-linear activation function.
  inline float _tanh(float s);
};

//--------------------------------------
// Approximated Tanh activation Neuron
// * output = tanh_approx(activation)
//
class ApproxTanhNeuron : public Neuron
{
public:
  //	Constructor: register the threshold neuron n as an input
  //	connection weighted by threshold.
  ApproxTanhNeuron(ThresholdNeuron	*n,
		   float		threshold = 0) { addConnection(n, threshold); }

  //    Process inputs and calculate activation (output)
  float refreshOutput(void);

  //	Get the derivative F'(s) at the current weighted sum _s:
  //	  0                    when |_s| > 1.92033
  //	  1 - 0.52074 * |_s|   otherwise
  //	Since 0.52074 * 1.92033 is about 1, the slope falls linearly from
  //	1 at _s == 0 down to 0 at the bound.
  //	|_s| is used so that negative sums get the same, non-negative slope
  //	as positive ones; the slope is therefore 1 at _s == 0.
  //	NOTE(review): assumes _tanh_approx is odd-symmetric with a
  //	saturation bound of 1.92033 - confirm against its definition.
  float	getFPrime(void) const
  {
    const double saturationBound = 1.92033;
    const double slopeDecay = 0.52074;
    const double magnitude = fabs(_s);

    if (magnitude > saturationBound)
      return 0;
    return (1 - slopeDecay * magnitude);
  }

  //	Display info to standard output
  void	display(void) const;


protected:

  //	Approximated Tanh non-linear activation function
  inline float	_tanh_approx(float s);

};

//----------------------------------------
// Sgn activation Neuron
// * output = 0 if activation < 0
//            1 if activation >= 0
//
class SgnNeuron : public Neuron
{
public:
  // Constructor: register the threshold neuron n as an input connection
  // weighted by threshold.
  SgnNeuron(ThresholdNeuron *n, float threshold = 0)
    : Neuron()
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate activation (output).
  virtual float refreshOutput();

  // Derivative F'(s).
  // FIXME: not derivable, so a constant 1 is returned as a placeholder.
  virtual float getFPrime() const { return 1; }

  // Display info to standard output.
  virtual void display() const;
};


//----------------------------------------
// Linear activation Neuron
// * output = activation
//
class LinearNeuron : public Neuron
{
public:
  // Constructor: register the threshold neuron n as an input connection
  // weighted by threshold.
  LinearNeuron(ThresholdNeuron *n, float threshold = 0)
    : Neuron()
  {
    addConnection(n, threshold);
  }

  // Process inputs and calculate activation (output).
  virtual float refreshOutput();

  // Derivative F'(s): constant 1.
  virtual float getFPrime() const { return 1; }

  // Display info to standard output.
  virtual void display() const;
};


//
// Alloc a new neuron.
//   type      : activation function identifier of the neuron wanted
//   tneuron   : threshold neuron handed to the new neuron
//   threshold : threshold weight handed to the new neuron
// NOTE(review): the definition is not in this header; presumably the
// result is heap-allocated and owned by the caller - confirm there.
//
Neuron *newNeuron(ActivationFunctionType type,
                  ThresholdNeuron *tneuron,
                  float threshold);

} // end NeuralNet namespace

#endif	    /* !NEURON_HH_ */
