// Copyright (C) 2002 Samy Bengio (bengio@idiap.ch)
//                
//
// This file is part of Torch. Release II.
// [The Ultimate Machine Learning Library]
//
// Torch is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 2 of the License, or
// (at your option) any later version.
//
// Torch is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Torch; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


#include "DiagonalGMM.h"
#include "log_add.h"
#include "random.h"

namespace Torch {

// Constructs a diagonal-covariance Gaussian mixture descriptor.
//
// Records the dimensions and the weight prior, takes a private copy of the
// per-dimension variance floor and registers the four initialisation
// options.  No parameter storage is allocated here; allocateMemory() does
// that.
//
//   n_observations_  number of values in one observation frame
//   n_gaussians_     number of mixture components
//   var_threshold_   array of n_observations_ variance floors, or NULL to
//                    use 1e-10 for every dimension
//   prior_weights_   value stored in prior_weights (eMIterInitialize() uses
//                    it as the starting value of each weight accumulator)
DiagonalGMM::DiagonalGMM(int n_observations_, int n_gaussians_, real* var_threshold_, real prior_weights_) : Distribution()
{
  n_observations = n_observations_;
  n_gaussians = n_gaussians_;
  prior_weights = prior_weights_;

  // The floor is always copied so the caller keeps ownership of its array.
  var_threshold = (real*)xalloc(sizeof(real)*n_observations);
  if (var_threshold_) {
    for (int d = 0; d < n_observations; d++)
      var_threshold[d] = var_threshold_[d];
  } else {
    for (int d = 0; d < n_observations; d++)
      var_threshold[d] = 1e-10;
  }

  // Each option starts unset (NULL); reset() tests them in this order.
  // The trailing `true` on the first registration is forwarded to
  // addOption as-is; its meaning is defined there (not visible here).
  initial_kmeans_trainer = NULL;
  addOption("initial kmeans trainer",sizeof(EMTrainer*),&initial_kmeans_trainer,"initial kmeans trainer", true);

  initial_kmeans_trainer_measurers = NULL;
  addOption("initial kmeans trainer measurers",sizeof(List*),&initial_kmeans_trainer_measurers,"initial kmeans trainer measurers");

  initial_params = NULL;
  addOption("initial params",sizeof(List*),&initial_params,"initial params");

  initial_file = NULL;
  addOption("initial file",sizeof(char*),&initial_file,"initial file");
}

void DiagonalGMM::allocateMemory()
{
  max_n_frames = 1;
  n_params = numberOfParams();
  addToList(&params,n_params,(real*)xalloc(sizeof(real)*n_params));
  addToList(&der_params,n_params,(real*)xalloc(sizeof(real)*n_params));
  addToList(&outputs,n_outputs,(real*)xalloc(sizeof(real)*n_outputs));
  real* p = (real*)params->ptr;
  real* dp = (real*)der_params->ptr;
  log_weights = p;
  dlog_weights = dp;
  p += n_gaussians;
  dp += n_gaussians;
  log_probabilities = (real*)xalloc(sizeof(real)*max_n_frames);
  log_probabilities_g = (real**)xalloc(sizeof(real*)*max_n_frames);
  means = (real**)xalloc(sizeof(real*)*n_gaussians);
  dmeans = (real**)xalloc(sizeof(real*)*n_gaussians);
  var = (real**)xalloc(sizeof(real*)*n_gaussians);
  dvar = (real**)xalloc(sizeof(real*)*n_gaussians);
  means_acc = (real**)xalloc(sizeof(real*)*n_gaussians);
  var_acc = (real**)xalloc(sizeof(real*)*n_gaussians);
  weights_acc = (real*)xalloc(sizeof(real)*n_gaussians);
  minus_half_over_var = (real**)xalloc(sizeof(real*)*n_gaussians);
  for (int i=0;i<max_n_frames;i++)
    log_probabilities_g[i] = (real*)xalloc(sizeof(real)*n_gaussians);
  for (int i=0;i<n_gaussians;i++) {
    means[i] = p;
    dmeans[i] = dp;
    p += n_observations;
    dp += n_observations;
    var[i] = p;
    dvar[i] = dp;
    p += n_observations;
    dp += n_observations;
    means_acc[i] = (real*)xalloc(sizeof(real)*n_observations);
    var_acc[i] = (real*)xalloc(sizeof(real)*n_observations);
    minus_half_over_var[i] = (real*)xalloc(sizeof(real)*n_observations);
  }
  sum_log_var_plus_n_obs_log_2_pi = 
    (real*)xalloc(sizeof(real)*n_gaussians);
}

// Releases every buffer obtained in allocateMemory().
//
// Each pointer is tested before being freed and reset to NULL afterwards,
// so calling this function a second time skips the buffers it already
// released.
//
// means/dmeans/var/dvar only hold row pointers into the params and
// der_params vectors (see allocateMemory()), so only the pointer tables
// themselves are freed; the rows go away with freeList().  The
// accumulators and the pre-computed tables own their rows, which are
// therefore freed one by one before the table.
void DiagonalGMM::freeMemory()
{
  freeList(&outputs,true);
  freeList(&params,true);
  freeList(&der_params,true);

  // Tables that own one row per Gaussian.
  if (means_acc) {
    for (int i=0;i<n_gaussians;i++)
      free(means_acc[i]);
    free(means_acc);
    means_acc = NULL;
  }
  if (var_acc) {
    for (int i=0;i<n_gaussians;i++)
      free(var_acc[i]);
    free(var_acc);
    var_acc = NULL;
  }
  if (minus_half_over_var) {
    for (int i=0;i<n_gaussians;i++)
      free(minus_half_over_var[i]);
    free(minus_half_over_var);
    minus_half_over_var = NULL;
  }

  // Per-frame cache: one row per frame up to max_n_frames.
  if (log_probabilities_g) {
    for (int i=0;i<max_n_frames;i++)
      free(log_probabilities_g[i]);
    free(log_probabilities_g);
    log_probabilities_g = NULL;
  }
  if (log_probabilities) {
    free(log_probabilities);
    log_probabilities = NULL;
  }

  // Pointer tables whose rows live inside params / der_params.
  if (means) {
    free(means);
    means = NULL;
  }
  if (dmeans) {
    free(dmeans);
    dmeans = NULL;
  }
  if (var) {
    free(var);
    var = NULL;
  }
  if (dvar) {
    free(dvar);
    dvar = NULL;
  }

  // Flat per-Gaussian vectors.
  if (weights_acc) {
    free(weights_acc);
    weights_acc = NULL;
  }
  if (sum_log_var_plus_n_obs_log_2_pi) {
    free(sum_log_var_plus_n_obs_log_2_pi);
    sum_log_var_plus_n_obs_log_2_pi = NULL;
  }
}

// Returns the length of the flat parameter vector: one mean and one
// variance per (Gaussian, observation dimension) pair, plus one weight per
// Gaussian.
int DiagonalGMM::numberOfParams()
{
  const int n_means_and_vars = n_observations * n_gaussians * 2;
  const int n_weights = n_gaussians;
  return n_means_and_vars + n_weights;
}

void DiagonalGMM::reset()
{
  // here, initialize the parameters somehow...

  if (initial_kmeans_trainer) {
    initial_kmeans_trainer->machine->reset();
    initial_kmeans_trainer->train(initial_kmeans_trainer_measurers);
    copyList(params,initial_kmeans_trainer->distribution->params);
  } else if (initial_params) {
    copyList(params,initial_params);
  } else if (initial_file) {
    load(initial_file);
  } else {
    // initialize randomly
    // first the weights
    real sum = 0.;
    for (int i=0;i<n_gaussians;i++) {
      log_weights[i] = bounded_uniform(0.1,1);
      sum += log_weights[i];
    }
    for (int i=0;i<n_gaussians;i++) {
      log_weights[i] = log(log_weights[i]/sum);
    }

    // then the means and variances
    for (int i=0;i<n_gaussians;i++) {
      for (int j=0;j<n_observations;j++) {
        means[i][j] = bounded_uniform(0,1);
        var[i][j] = bounded_uniform(var_threshold[j],var_threshold[j]*10);
      }
    }
  }
}

// Makes sure the per-frame caches can hold the sequence found in `inputs`.
//
// `inputs->ptr` is read as a SeqExample.  When its n_real_frames exceeds
// max_n_frames, log_probabilities and log_probabilities_g are grown to
// that size and a row of n_gaussians values is allocated for every new
// frame.  The caches are never shrunk.  A NULL `inputs` is ignored.
void DiagonalGMM::eMSequenceInitialize(List* inputs)
{
  if (!inputs)
    return;

  SeqExample* ex = (SeqExample*)inputs->ptr;
  if (ex->n_real_frames <= max_n_frames)
    return;   // already large enough

  const int first_new_frame = max_n_frames;
  max_n_frames = ex->n_real_frames;

  log_probabilities_g = (real**)xrealloc(log_probabilities_g,sizeof(real*)*max_n_frames);
  log_probabilities = (real*)xrealloc(log_probabilities,sizeof(real)*max_n_frames);

  // Rows below first_new_frame already exist and are kept.
  for (int t = first_new_frame; t < max_n_frames; t++)
    log_probabilities_g[t] = (real*)xalloc(sizeof(real)*n_gaussians);
}

void DiagonalGMM::display()
{
  for(int i=0;i<n_gaussians;i++){
    printf("Mixture %d %.3g\n",i+1,exp(log_weights[i]));
    printf("Mean\n");
    for(int j=0;j<n_observations;j++){
       printf("%.3g ",means[i][j]);
    }
    printf("\nVar\n");
    for(int j=0;j<n_observations;j++){
       printf("%.3g ",var[i][j]);
    }
    printf("\n"); 
 }
}

// Prepares the model for a new sequence.
//
// After growing the per-frame caches (eMSequenceInitialize), this clears
// every derivative (weights, means, variances), raises any variance below
// its per-dimension floor up to that floor, and refreshes the two
// pre-computed tables:
//   minus_half_over_var[g][d]           = -0.5 / var[g][d]
//   sum_log_var_plus_n_obs_log_2_pi[g]  = -0.5 * (n_observations*LOG_2_PI
//                                                  + sum_d log(var[g][d]))
void DiagonalGMM::sequenceInitialize(List* inputs)
{
  eMSequenceInitialize(inputs);

  for (int g = 0; g < n_gaussians; g++) {
    real* var_g = var[g];
    real* dmeans_g = dmeans[g];
    real* dvar_g = dvar[g];
    real* mhov_g = minus_half_over_var[g];

    dlog_weights[g] = 0;
    sum_log_var_plus_n_obs_log_2_pi[g] = n_observations * LOG_2_PI;

    for (int d = 0; d < n_observations; d++) {
      dmeans_g[d] = 0;
      dvar_g[d] = 0;

      // Enforce the variance floor before using the variance.
      if (var_g[d] < var_threshold[d])
        var_g[d] = var_threshold[d];

      mhov_g[d] = -0.5 / var_g[d];
      sum_log_var_plus_n_obs_log_2_pi[g] += log(var_g[d]);
    }

    sum_log_var_plus_n_obs_log_2_pi[g] *= -0.5;
  }
}

// Fills `data` with a single sequence of `n_` frames sampled from the model.
//
// The example pointed to by data->examples is overwritten: its frame
// counts are set to n_, every optional field is cleared, and a freshly
// allocated observation matrix (n_ rows of n_observations values) is
// filled row by row with generateObservation().  data->init() is called
// last.
void DiagonalGMM::generate(SeqDataSet* data,int n_){
  data->n_observations = n_observations;
  data->tot_n_frames = n_;

  SeqExample* ex = data->examples;

  // Sizes.
  ex->n_frames = data->tot_n_frames;
  ex->n_real_frames = data->tot_n_frames;
  ex->n_seqtargets = 0;
  ex->n_alignments = 0;

  // Fields left unset for a generated example.
  ex->inputs = NULL;
  ex->seqtargets = NULL;
  ex->selected_frames = NULL;
  ex->alignment = NULL;
  ex->name = NULL;

  // One row per frame, each drawn from the mixture.
  ex->observations = (real**)xalloc(sizeof(real*)*data->tot_n_frames);
  for (int t = 0; t < data->tot_n_frames; t++) {
    ex->observations[t] = (real*)xalloc(sizeof(real)*data->n_observations);
    generateObservation(ex->observations[t]);
  }

  data->init();
}

// Samples one observation from the mixture into `observation`
// (n_observations values are written).
//
// A Gaussian is picked by drawing uniform() and walking the cumulated
// weights until the running total exceeds the draw; if it never does, the
// last Gaussian is used.  Each dimension is then drawn with
// gaussian_mu_sigma() from that Gaussian's mean and the square root of its
// variance.
void DiagonalGMM::generateObservation(real* observation)
{
  real noise_fact = 1.0;   // scale applied to the standard deviation

  // choose a Gaussian
  real v_tot = uniform();
  real v_partial = 0.;
  int chosen = 0;
  while (chosen < n_gaussians) {
    v_partial += exp(log_weights[chosen]);
    if (v_partial > v_tot)
      break;
    chosen++;
  }
  if (chosen >= n_gaussians)
    chosen = n_gaussians - 1;

  // draw every dimension from the chosen Gaussian
  real* mean_c = means[chosen];
  real* var_c = var[chosen];
  for (int d = 0; d < n_observations; d++)
    observation[d] = gaussian_mu_sigma(mean_c[d],noise_fact * sqrt(var_c[d]));
}


// Returns the value computed for Gaussian `g` on one frame:
//   sum_j (x[j] - mean[g][j])^2 * minus_half_over_var[g][j]
//     + sum_log_var_plus_n_obs_log_2_pi[g]
// Both tables must have been refreshed beforehand (sequenceInitialize() or
// eMIterInitialize() do so).  `inputs` is not used.
real DiagonalGMM::frameLogProbabilityOneGaussian(real *observations, real *inputs, int g)
{
  const real* mean_g = means[g];
  const real* mhov_g = minus_half_over_var[g];

  real quad = 0.;
  for (int j = 0; j < n_observations; j++) {
    real diff = observations[j] - mean_g[j];
    quad += diff*diff * mhov_g[j];
  }

  return quad + sum_log_var_plus_n_obs_log_2_pi[g];
}

// Computes the mixture log-probability of frame `t`.
//
// Every per-Gaussian value is cached in log_probabilities_g[t], and the
// log_add combination of (per-Gaussian value + log weight) is cached in
// log_probabilities[t] and returned.
real DiagonalGMM::frameLogProbability(real *observations, real *inputs, int t)
{
  real* lp_t = log_probabilities_g[t];

  real total = LOG_ZERO;
  for (int g = 0; g < n_gaussians; g++) {
    lp_t[g] = frameLogProbabilityOneGaussian(observations,inputs,g);
    total = log_add(total, lp_t[g] + log_weights[g]);
  }

  log_probabilities[t] = total;
  return total;
}

// Accumulates the EM statistics of frame `t`.
//
// For each Gaussian g the posterior
//   exp(log_posterior + log_weights[g] + log_probabilities_g[t][g]
//       - log_probabilities[t])
// is added to weights_acc[g]; posterior * x is added to means_acc[g] and
// posterior * x * x to var_acc[g], dimension by dimension.  The cached
// values for frame `t` must already have been written by
// frameLogProbability().  `inputs` is not used.
void DiagonalGMM::frameEMAccPosteriors(real *observations, real log_posterior, real *inputs, int t)
{
  const real frame_log_prob = log_probabilities[t];
  const real* lp_t = log_probabilities_g[t];

  for (int g = 0; g < n_gaussians; g++) {
    real post = exp(log_posterior + log_weights[g] + lp_t[g] - frame_log_prob);
    weights_acc[g] += post;

    real* mean_acc_g = means_acc[g];
    real* var_acc_g = var_acc[g];
    for (int j = 0; j < n_observations; j++) {
      var_acc_g[j] += post * observations[j] * observations[j];
      mean_acc_g[j] += post * observations[j];
    }
  }
}

void DiagonalGMM::eMUpdate()
{
  // first the gaussians
  real* p_weights_acc = weights_acc;
  for (int i=0;i<n_gaussians;i++,p_weights_acc++) {
    if (*p_weights_acc == 0) {
      warning("Gaussian %d of GMM is not used in EM",i);
    } else {
      real* p_means_i = means[i];
      real* p_var_i = var[i];
      real* p_means_acc_i = means_acc[i];
      real* p_var_acc_i = var_acc[i];
      for (int j=0;j<n_observations;j++) {
        *p_means_i = *p_means_acc_i++ / *p_weights_acc;
        real v = *p_var_acc_i++ / *p_weights_acc - *p_means_i * *p_means_i++;
        *p_var_i++ = v >= var_threshold[j] ? v : var_threshold[j];
      }
    }
  }
  // then the weights
  real sum_weights_acc = 0;
  p_weights_acc = weights_acc;
  for (int i=0;i<n_gaussians;i++)
    sum_weights_acc += *p_weights_acc++;
  real *p_log_weights = log_weights;
  real log_sum = log(sum_weights_acc);
  p_weights_acc = weights_acc;
  for (int i=0;i<n_gaussians;i++)
    *p_log_weights++ = log(*p_weights_acc++) - log_sum;
}

void DiagonalGMM::eMIterInitialize()
{
  // initialize the accumulators to 0 and compute pre-computed value
  for (int i=0;i<n_gaussians;i++) {
    real *pm = means_acc[i];
    real *ps = var_acc[i];
    real *v = var[i];
    real *sum = &sum_log_var_plus_n_obs_log_2_pi[i];
    *sum = n_observations * LOG_2_PI;
    real *mh_i = minus_half_over_var[i];
    for (int j=0;j<n_observations;j++) {
      *pm++ = 0.;
      *ps++ = 0.;
      *mh_i++ = -0.5 / *v;
      *sum += log(*v++);
    }
    *sum *= -0.5;
    weights_acc[i] = prior_weights;
  }
}

// Intentionally empty: this function performs no work for this class.
// (The derivatives are cleared per sequence in sequenceInitialize().)
void DiagonalGMM::iterInitialize()
{
}

// Accumulates the derivatives for frame `t` into dlog_weights, dmeans and
// dvar, using the values cached by frameLogProbability() for that frame.
//
// For each Gaussian g, with
//   post = - *alpha * exp(log_weights[g] + log_probabilities_g[t][g]
//                         - log_probabilities[t])
// the code adds `post` to dlog_weights[g], subtracts
// post * exp(log_weights[k]) from every dlog_weights[k], and for every
// dimension j, with r = (x[j] - mean[g][j]) / var[g][j]:
//   dmeans[g][j] += post * 2 * r
//   dvar[g][j]   += post * 0.5 * (r*r - 1/var[g][j])
// `inputs` is not used.
//
// NOTE(review): the analytic derivative of a Gaussian log-density with
// respect to its mean is (x - mean)/var, i.e. without the factor 2 applied
// to dmeans here.  The factor is reproduced unchanged; confirm whether it
// is intentional.
void DiagonalGMM::frameBackward(real *observations, real *alpha, real *inputs, int t)
{
  const real frame_log_prob = log_probabilities[t];
  const real* lp_t = log_probabilities_g[t];

  for (int g = 0; g < n_gaussians; g++) {
    real post = - *alpha * exp(log_weights[g] + lp_t[g] - frame_log_prob);

    // Weight derivatives: own term first, then the term shared by all.
    dlog_weights[g] += post;
    for (int k = 0; k < n_gaussians; k++)
      dlog_weights[k] -= post * exp(log_weights[k]);

    const real* mean_g = means[g];
    const real* var_g = var[g];
    real* dmean_g = dmeans[g];
    real* dvar_g = dvar[g];
    for (int j = 0; j < n_observations; j++) {
      real xmuvar = (observations[j] - mean_g[j]) / var_g[j];
      real dm = post * 2. * xmuvar;
      dmean_g[j] += dm;
      dvar_g[j] += post * 0.5 * (xmuvar*xmuvar - 1./ var_g[j]);
    }
  }
}

// Writes the expected observation of the mixture into `observations`:
// the buffer is cleared, then each Gaussian's mean is added, scaled by the
// exponential of its log weight.  `inputs` and `t` are not used.
void DiagonalGMM::frameExpectation(real *observations, real *inputs, int t)
{
  for (int j = 0; j < n_observations; j++)
    observations[j] = 0;

  for (int g = 0; g < n_gaussians; g++) {
    const real* mean_g = means[g];
    real w = exp(log_weights[g]);
    for (int j = 0; j < n_observations; j++)
      observations[j] += w * mean_g[j];
  }
}

void DiagonalGMM::setNGaussians(int n_gaussians_)
{
  n_gaussians = n_gaussians_;
}

// Releases the model buffers through freeMemory() and the private copy of
// the variance floor made by the constructor.  The two releases touch
// disjoint storage.
DiagonalGMM::~DiagonalGMM()
{
  freeMemory();
  free(var_threshold);
}

}

