// Banner printed by CmdLine::info() at the top of the usage message.
// Adjacent string literals are concatenated by the compiler into one string.
const char *help =
  "LittleSVMTorch (c) Trebolloc & Co 2001\n"
  "\n"
  "This program will train a SVM with a gaussian kernel in\n"
  "classification or regression.\n";

#include "FileDataSet.h"
#include "MseCriterion.h"
#include "MseMeasurer.h"
#include "ClassMeasurer.h"
#include "TwoClassFormat.h"
#include "CmdLine.h"
#include "SVMClassification.h"
#include "SVMRegression.h"
#include "Kernel.h"
#include "QCTrainer.h"
#include "SVMCache.h"

using namespace Torch;

int main(int argc, char **argv)
{
  char *model_file, *test_file;
  char *file;

  int max_load;
  bool regression;
  int k_fold;
  int the_seed;
  int n_inputs;

  real std;
  real cache_size;
  real eps_tube;
  real accuracy;
  int iter_shrink;

  //=================== The command-line ==========================

  // Construct the command line
  CmdLine cmd;

  // Put the help line at the beginning
  cmd.info(help);

  // Ask for arguments
  cmd.addText("\nArguments:");
  cmd.addSCmdArg("file", &file, "the *train* file");
  cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");

  // Propose some options
  cmd.addText("\nModel Options:");
  cmd.addRCmdOption("-std", &std, 10., "the std parameter in the gaussian kernel");
  cmd.addBCmdOption("-rm", &regression, false, "regression mode");
  cmd.addRCmdOption("-eps", &eps_tube, 0.7, "eps tube in regression");

  cmd.addText("\nLearning Options:");
  cmd.addRCmdOption("-e", &accuracy, 0.01, "end accuracy");
  cmd.addRCmdOption("-m", &cache_size, 50., "cache size in Mo");
  cmd.addICmdOption("-h", &iter_shrink, 100, "minimal number of iterations before shrinking");

  cmd.addText("\nMisc Options:");
  cmd.addICmdOption("-seed", &the_seed, -1, "the random seed");
  cmd.addICmdOption("-Kfold", &k_fold, -1, "number of subsets for K-fold cross-validation");
  cmd.addICmdOption("-load", &max_load, -1, "max number of examples to load");
  cmd.addSCmdOption("-model", &model_file, "", "model file to load (if there is a -test option)/save (else)");
  cmd.addSCmdOption("-test", &test_file, "", "the test file");

  // Read the command line
  cmd.read(argc, argv);

  // If the user didn't give any random seed,
  // generate a random random seed...
  if(the_seed == -1)
    seed();
  else
    manual_seed((long)the_seed);

  if( strcmp(test_file, "") && (k_fold > 0) )
     error("Please do not provide a test file when you want to do K-fold...");

  //=================== Training DataSet ==========================

  // Create the training dataset (normalize inputs)
  FileDataSet data(file, n_inputs, 1, false, max_load);
//  data.setBOption("normalize inputs", true);
  data.init();

  //=================== Create the SVM... =========================

  // The Kernel
  GaussianKernel kernel(&data);
  kernel.setROption("gamma", 1./(std*std));
  kernel.init();

  // The SVM
  SVM *svm;

  if(regression)
  {
    svm = new SVMRegression(&kernel);
    svm->setROption("eps regression", eps_tube);
  }
  else
    svm = new SVMClassification(&kernel);
  svm->init();

  // The SVM-cache
  SVMCache *cache;

  if(regression)
    cache = new SVMCacheRegression((SVMRegression *)svm, cache_size);
  else
    cache = new SVMCacheClassification((SVMClassification *)svm, cache_size);


  //=================== Test DataSet & Measurers... ===============

  // The list of measurers...
  List *measurers = NULL;

  // The class format
  ClassFormat *class_format = NULL;
  if(!regression)
     class_format = new TwoClassFormat(&data);

  // The test set...
  FileDataSet *test_data = NULL;
  MseMeasurer *test_mse_meas = NULL;
  ClassMeasurer *test_class_meas = NULL;

  // Create a test set, if any
  if(strcmp(test_file, ""))
  {
    // Load the test set
    test_data = new FileDataSet(test_file, n_inputs, 1);
    test_data->init();
//    test_data->normalizeUsingDataSet(&data);

    // Create a MSE measurer and an error class measurer
    // on the test dataset (if we are not in regression)
    test_mse_meas = new MseMeasurer(svm->outputs, test_data, "the_test_mse");
    test_mse_meas->init();
    addToList(&measurers, 1, test_mse_meas);

    if(!regression)
    {
      test_class_meas = new ClassMeasurer(svm->outputs, test_data, class_format, "the_test_class_err");
      test_class_meas->init();
      addToList(&measurers, 1, test_class_meas);
    }
  }

  // Measurers on the training dataset
  MseMeasurer *mse_meas = new MseMeasurer(svm->outputs, &data, "the_mse");
  mse_meas->init();
  addToList(&measurers, 1, mse_meas);

  ClassMeasurer *class_meas = NULL;
  if(!regression)
  {
    class_meas = new ClassMeasurer(svm->outputs, &data, class_format, "the_class_err");
    class_meas->init();
    addToList(&measurers, 1, class_meas);
  }

  //=================== The Trainer ===============================
  
  QCTrainer trainer(svm, &data, cache);
  trainer.setROption("end accuracy", accuracy);
  trainer.setIOption("iter shrink", iter_shrink);

  //=================== Let's go... ===============================

  // If the user provides a previously trained model,
  // and a test dataset, test it...
  if( strcmp(model_file, "") && strcmp(test_file, ""))
  {
    trainer.load(model_file);
    trainer.test(measurers);
  }

  // ...else...
  else
  {
    // If the user provides a number for the K-fold validation,
    // do a K-fold validation
    if(k_fold > 0)
      trainer.crossValidate(k_fold, NULL, measurers);

    // Else, train the model
    else
    {
      trainer.train(NULL);
      message("%d supports vectors", svm->n_support_vectors);

      // Save the model if the user provides a name for that
      if( strcmp(model_file, "") )
         trainer.save(model_file);
    }
  }

  //=================== Quit... ===================================
  if(strcmp(test_file, ""))
  {
    delete test_data;
    delete test_mse_meas;
    if(!regression)
      delete test_class_meas;
  }

  delete mse_meas;
  if(!regression)
  {
    delete class_meas;
    delete class_format;
  }

  delete svm;
  delete cache;

  freeList(&measurers);

  return(0);
}
