#include <errno.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <unistd.h>

/* Input->hidden weights: wa<i><j> connects input i to hidden unit j. */
float wa11, wa12, wa21, wa22;
/* Hidden->output weights: wb<j> connects hidden unit j to the output. */
float wb1, wb2;

/* Weight changes applied by the most recent backprop() step. */
float dwa11, dwa12, dwa21, dwa22, dwb1, dwb2;
/* Previous step's weight changes, fed back through the momentum term. */
float dowa11 = 0.0, dowa12 = 0.0, dowa21 = 0.0, dowa22 = 0.0;
float dowb1 = 0.0, dowb2 = 0.0;

/* Error deltas for the output unit and the two hidden units. */
float grad_out, grad_hid1, grad_hid2;
/* Pre-activation (weighted-sum) values of the hidden and output units. */
float hid1_i, hid2_i, out_i;
/* Network inputs, and the sigmoid activations of hidden and output units. */
float in1, in2, hid1, hid2, out;

/* Desired output for the sample currently loaded into in1/in2. */
float target;

/* Learning rate (scales each gradient step). */
float mu = 0.3;
/* Momentum coefficient (scales the previous step); 0.0 disables momentum. */
float mc = 0.0;

void run_net(void);
void backprop(void);
float sigmoid(float x);
float sigmoid_deriv(float y);
float nicerand(void);
float randhalf(void);

/* One labelled example: inputs a, b and desired output t. */
struct net_sample {
  float a;
  float b;
  float t;
};

/* Training examples in the order they are presented on every iteration.
 * Equal inputs map to 1.0 and differing inputs to 0.0 (XNOR). */
static const struct net_sample train_set[] = {
  { 0.0f, 0.0f, 1.0f },
  { 1.0f, 1.0f, 1.0f },
  { 1.0f, 0.0f, 0.0f },
  { 0.0f, 1.0f, 0.0f },
};

/* The same examples in the order the final report prints them. */
static const struct net_sample report_set[] = {
  { 0.0f, 0.0f, 1.0f },
  { 0.0f, 1.0f, 0.0f },
  { 1.0f, 0.0f, 0.0f },
  { 1.0f, 1.0f, 1.0f },
};

/* Copy one example into the globals that run_net() and backprop() read. */
static void load_sample ( const struct net_sample *s ) {
  in1 = s->a;
  in2 = s->b;
  target = s->t;
}

/* Parse a base-10, non-negative iteration count.
 * Returns -1 if text is empty, has trailing characters, is negative, or
 * does not fit in a long. */
static long parse_iters ( const char *text ) {
  char *end;
  long v;

  errno = 0;
  v = strtol(text, &end, 10);
  if (end == text || *end != '\0' || errno == ERANGE || v < 0) {
    return -1;
  }
  return v;
}

/* Train the network for argv[1] passes over train_set, then print the
 * network's output and error (target - output) for each input pair. */
int main ( int argc, char *argv[] ) {
  long iters;
  long i;
  size_t k;

  if (argc != 2) {
    fprintf(stderr, "num iters as arg!!!\n");
    return 1;
  }

  /* Parsed once up front: the original re-ran atoi() on every loop test
   * and silently treated junk input as 0 iterations. */
  iters = parse_iters(argv[1]);
  if (iters < 0) {
    fprintf(stderr, "invalid iteration count: %s\n", argv[1]);
    return 1;
  }

  /* Unsigned arithmetic wraps; a signed product could overflow (UB). */
  srand((unsigned)getpid() * (unsigned)time(NULL));

  /* Small random initial weights. The call order determines which rand()
   * value each weight receives. */
  wa11 = nicerand();
  wa12 = nicerand();
  wa21 = nicerand();
  wa22 = nicerand();
  wb1  = nicerand();
  wb2  = nicerand();

  /* Online training: one forward pass and one weight update per example. */
  for (i = 0; i < iters; i++) {
    for (k = 0; k < sizeof train_set / sizeof train_set[0]; k++) {
      load_sample(&train_set[k]);
      run_net();
      backprop();
    }
  }

  /* Report; %.0f prints the 0.0/1.0 inputs as "0"/"1". */
  for (k = 0; k < sizeof report_set / sizeof report_set[0]; k++) {
    load_sample(&report_set[k]);
    run_net();
    printf("%.0f,%.0f --> %.3f (e %.3f)\n",
           report_set[k].a, report_set[k].b, out, target - out);
  }

  return 0;
}

/* Forward pass of the 2-2-1 network (no bias terms).
 * Reads the inputs in1, in2 (set by the caller beforehand) and the weights.
 * Writes the pre-activations hid1_i, hid2_i, out_i and the sigmoid
 * activations hid1, hid2 and out; out is the network's result. */
void run_net ( void ) {
  /* input -> hidden: wa<i><j> connects input i to hidden unit j */
  hid1_i = in1 * wa11 + in2 * wa21;
  hid1   = sigmoid(hid1_i);

  hid2_i = in1 * wa12 + in2 * wa22;
  hid2   = sigmoid(hid2_i);

  /* hidden -> output */
  out_i = hid1 * wb1 + hid2 * wb2;
  out   = sigmoid(out_i);
}

/* One online gradient-descent step (with momentum) for the 2-2-1 network,
 * using the sample most recently passed through run_net().
 *
 * Reads target, in1, in2 and the activations hid1, hid2, out left by
 * run_net(). Updates the deltas grad_out/grad_hid1/grad_hid2, the per-step
 * weight changes dw*, the remembered previous changes dow*, and every
 * weight. mu is the learning rate; mc scales the previous change (momentum).
 */
void backprop ( void ) {

  float error=(target-out);

  /* Output delta. sigmoid_deriv() takes an activation value, so pass out. */
  grad_out=sigmoid_deriv(out)*error;

  /* Hidden deltas are propagated back through the output weights that
   * produced this forward pass. They must be computed BEFORE wb1/wb2 are
   * modified below; otherwise the gradient mixes old activations with
   * already-updated weights. */
  grad_hid1=sigmoid_deriv(hid1)*wb1*grad_out;
  grad_hid2=sigmoid_deriv(hid2)*wb2*grad_out;

  /* hidden -> output weights */
  dwb1=(mu*grad_out*hid1) + (mc*dowb1);
  dwb2=(mu*grad_out*hid2) + (mc*dowb2);

  /* remember this step for the next call's momentum term */
  dowb1=dwb1;
  dowb2=dwb2;

  wb1+=dwb1;
  wb2+=dwb2;

  /* input -> hidden weights (wa<i><j>: input i to hidden unit j) */
  dwa11=(mu*grad_hid1*in1)+(mc*dowa11);
  dwa21=(mu*grad_hid1*in2)+(mc*dowa21);
  dwa12=(mu*grad_hid2*in1)+(mc*dowa12);
  dwa22=(mu*grad_hid2*in2)+(mc*dowa22);

  dowa11=dwa11;
  dowa21=dwa21;
  dowa12=dwa12;
  dowa22=dwa22;

  wa11+=dwa11;
  wa21+=dwa21;
  wa12+=dwa12;
  wa22+=dwa22;
}

float sigmoid( float x ) {
  return (1/(1+exp(-x)));
}

/* Derivative of the logistic sigmoid expressed through its output:
 * if y = sigmoid(z) then dy/dz = y * (1 - y). The argument is therefore an
 * activation value (backprop passes out, hid1, hid2), not the pre-activation.
 * Computed in double and rounded to float on return. */
float sigmoid_deriv ( float y ) {
  const double act = y;
  return (float)(act * (1.0 - act));
}

/* Return a pseudo-random initial weight in [-0.1, 0.1], scaled linearly
 * from rand() (seed with srand() for a different sequence per run). */
float nicerand ( void ) {
  /* Divide in double: (float)RAND_MAX is not exact when RAND_MAX is
   * 2^31 - 1, so a float quotient would be rounded needlessly. */
  const double unit = (double)rand() / (double)RAND_MAX;   /* in [0, 1] */
  return (float)(0.2 * unit - 0.1);
}

/* Pseudo-random value in [0, 0.5], scaled linearly from rand(). */
float randhalf( void ) {
  float unit = (float)rand() / (float)RAND_MAX;   /* in [0, 1] */
  return (float)(unit / 2.0);
}

