/*************************************************************************/
/*                                                                       */
/*                Centre for Speech Technology Research                  */
/*                     University of Edinburgh, UK                       */
/*                      Copyright (c) 1995,1996                          */
/*                        All Rights Reserved.                           */
/*                                                                       */
/*  Permission to use, copy, modify, distribute this software and its    */
/*  documentation for research, educational and individual use only, is  */
/*  hereby granted without fee, subject to the following conditions:     */
/*   1. The code must retain the above copyright notice, this list of    */
/*      conditions and the following disclaimer.                         */
/*   2. Any modifications must be clearly marked as such.                */
/*   3. Original authors' names are not deleted.                         */
/*  This software may not be used for commercial purposes without        */
/*  specific prior written permission from the authors.                  */
/*                                                                       */
/*  THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK        */
/*  DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING      */
/*  ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT   */
/*  SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE     */
/*  FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES    */
/*  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN   */
/*  AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION,          */
/*  ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF       */
/*  THIS SOFTWARE.                                                       */
/*                                                                       */
/*************************************************************************/
/*             Author :  Paul Taylor                                     */
/*             Date   :  June 1996                                       */
/*-----------------------------------------------------------------------*/
/*                  HMM Class Source                                     */
/*                                                                       */
/*=======================================================================*/

#include <stdlib.h>
#include "HMM.h"
#include "EST.h"
//#include "Token.h"
#include "fstream.h"
#include "iostream.h"



// Default-construct an empty mixture component (no mean/variance values).
// gconst is only assigned by parse_mixture() when a <gconst> entry is
// present, so give it a defined value here instead of leaving it
// indeterminate.
HMM_Mixture::HMM_Mixture()
{
    vecsize = 0;
    gconst = 0.0;
}

// Construct a mixture component whose mean and variance vectors both have
// n elements.  gconst starts at 0 because it is only filled in when the
// input contains an explicit <gconst> entry.
HMM_Mixture::HMM_Mixture(int n)
{
    vecsize = n;
    mean.resize(n);
    var.resize(n);
    gconst = 0.0;
}

// Copy constructor: member-wise copy of the vector size, the mean and
// variance vectors, and gconst.  gconst used to be dropped here, so any
// value read from a <gconst> entry was lost whenever a mixture was copied.
HMM_Mixture::HMM_Mixture(HMM_Mixture &m)
{
    vecsize = m.vecsize;
    mean = m.mean;
    var = m.var;
    gconst = m.gconst;
}

// Assignment: member-wise copy including gconst.  Vector assignment of
// mixtures (e.g. HMM_State::mixture = ...) goes through this operator, so
// omitting gconst silently discarded it on every state/HMM copy.
HMM_Mixture & HMM_Mixture::operator=(const HMM_Mixture &m)
{
    vecsize = m.vecsize;
    mean = m.mean;
    var = m.var;
    gconst = m.gconst;
    return *this;
}

// Default-construct a state with no mixture components; HMM_State::init()
// is what later sizes the mixtures and their weights.
HMM_State::HMM_State()
{
    const int no_components = 0;
    mixture.resize(no_components);
}

// Construct a state with n mixture components.  The weight vector is sized
// to match so that m_weight(i) is a valid index for every mixture i (the
// stream writer indexes both in lock-step).  The individual mixtures are
// not given a vector size here; HMM_State::init() does that.
HMM_State::HMM_State(int n)
{
    mixture.resize(n);
    m_weight.resize(n);
}

// Discard every state of the model.  Only the state vector is emptied;
// trans, vecsize, name and the kind strings are left as they are.
void HMM::clear(void)
{
    const int no_states = 0;
    state.resize(no_states);
}

// Copy constructor: copies the mixture components and their weights.
// Previously only the mixtures were copied, so a copy-constructed state
// lost its mixture weights (unlike HMM_State::operator=, which copies both).
HMM_State::HMM_State(HMM_State &s)
{
    mixture = s.mixture;
    m_weight = s.m_weight;
}

// Assignment: replace this state's mixture weights and components with
// copies of those in m.
HMM_State & HMM_State::operator=(const HMM_State &m)
{
    m_weight = m.m_weight;
    mixture = m.mixture;
    return *this;
}

// Default-construct an empty model: no states, zero vector size, no name.
// num_streams is initialised too, because HMM::operator= reads it from the
// source model and would otherwise copy an indeterminate value.
HMM::HMM()
{
    state.resize(0);   // was "state = 0", which relied on an implicit conversion
    vecsize = 0;
    num_streams = 0;
    name = "";
}

// Construct a model with n (default-constructed) states and observation
// vector size v.  The transition matrix is not sized here; HMM::init() does
// that.  num_streams is given a defined value because HMM::operator= reads
// it from the source model.
HMM::HMM(int n, int v)
{
    state.resize(n);
    vecsize = v;
    num_streams = 0;
    name = "";
}

// Assignment: make this model a copy of h.
HMM & HMM::operator=(const HMM &h)
{
    // Size storage like the source; init() also copies vecsize/num_streams.
    init(h.state.n(), h.vecsize, h.num_streams);

    // Per-state parameters and transition matrix.
    state = h.state;
    trans = h.trans;

    // Descriptive fields.
    name = h.name;
    covkind = h.covkind;
    durkind = h.durkind;
    sampkind = h.sampkind;

    return *this;
}

// Placeholder with no implementation; calling it has no effect.
void HMM::balls(void)
{
    return;
}


// (Re)size the model: record the vector size and stream count, then
// allocate n_states states and a square n_states x n_states transition
// matrix.  Existing state contents are whatever resize() leaves behind.
void HMM::init(int n_states, int vsize, int n_streams)
{
    vecsize = vsize;
    num_streams = n_streams;

    state.resize(n_states);
    trans.resize(n_states, n_states);
}

// Size the state for n_mixes mixture components, each holding vsize-element
// mean and variance vectors, with one weight slot per component.
void HMM_State::init(int n_mixes, int vsize)
{
    m_weight.resize(n_mixes);
    mixture.resize(n_mixes);

    int k = 0;
    while (k < n_mixes)
    {
        mixture(k).init(vsize);
        ++k;
    }
}

// Resize both parameter vectors of this component to vsize elements and
// record that size.
void HMM_Mixture::init(int vsize)
{
    vecsize = vsize;
    var.resize(vsize);
    mean.resize(vsize);
}

// Intended to check whether name is a valid sample-kind token.  Currently
// a stub: no validation is done and every name is accepted (returns 1).
// Note that this makes the "Unknown Option" branches of the option parsers
// unreachable.
int legit_sample_kind(EST_String name)
{
    (void) name;
    const int accepted = 1;
    return accepted;
}

// Require the next token (compared in lower case) to be "<key>".
// On a match the token is consumed and 0 is returned; otherwise an error is
// printed to cerr, the token is left in the stream and -1 is returned.
int parse_key(EST_TokenStream &ts, EST_String key)
{
    EST_String expected = "<" + key + ">";
    EST_String found = ts.peek().lstring();

    if (found == expected)
    {
        ts.get();
        return 0;
    }

    cerr << "Expected start of " << expected << " definition, got: "
         << ts.peek().lstring() << endl;
    return -1;
}

// Like parse_key(), but silent: if the next token (lower-cased) is "<key>"
// it is consumed and 0 is returned, otherwise the stream is untouched and
// -1 is returned.
int parse_optional_key(EST_TokenStream &ts, EST_String key)
{
    EST_String wanted = "<" + key + ">";
    EST_String upcoming = ts.peek().lstring();

    if (upcoming == wanted)
    {
        ts.get();
        return 0;
    }
    return -1;
}

// Consume one token and convert it with atoi().  When n > 0 the value must
// equal n, otherwise an error naming key is printed and -1 is returned.
// When n <= 0 any value is accepted.  The parsed value itself is not
// returned to the caller.
int parse_int(EST_TokenStream &ts, EST_String key, int n)
{
    EST_String token = ts.get().lstring();
    int value = atoi(token);

    if ((n <= 0) || (value == n))
        return 0;

    cerr << "Expected " << key << " value of " << n << 
        ", got: " << value <<endl;
    return -1;
}

// Read one mixture component:
//   <mean> N  m1 .. mN
//   <variance> N  v1 .. vN
//   [<gconst> g]
// N is checked against mix.vecsize (only when that is positive) and exactly
// mix.vecsize values are read into each vector.  Returns 0 on success, -1
// if a required key or size check fails.
int parse_mixture(EST_TokenStream &ts, HMM_Mixture &mix)
{
    int k;

    if ((parse_key(ts, "mean") == -1) ||
        (parse_int(ts, "mean", mix.vecsize) == -1))
        return -1;
    for (k = 0; k < mix.vecsize; ++k)
        mix.mean(k) = atof(ts.get().string());

    if ((parse_key(ts, "variance") == -1) ||
        (parse_int(ts, "variance", mix.vecsize) == -1))
        return -1;
    for (k = 0; k < mix.vecsize; ++k)
        mix.var(k) = atof(ts.get().string());

    // gconst is optional; mix.gconst is only written when it is present.
    if (parse_optional_key(ts, "gconst") == 0)
        mix.gconst = atof(ts.get().string());

    return 0;
}

// Read "<transp> N" followed by N*N values into t, filling it row by row
// (t(0,0), t(0,1), ...).  t must already be at least n_states x n_states.
// Returns -1 if the key is missing or N != n_states (when n_states > 0).
int parse_transition_matrix(EST_TokenStream &ts, EST_FMatrix &t, int n_states)
{
    if ((parse_key(ts, "transp") == -1) ||
        (parse_int(ts, "transp", n_states) == -1))
        return -1;

    for (int row = 0; row < n_states; ++row)
    {
        for (int col = 0; col < n_states; ++col)
            t(row, col) = atof(ts.get().string());
    }

    return 0;
}

// Read one state definition:
//   <state> K  [<nummixes> M]  { [<mixture> i w] mixture-body }
// where K must equal state_num + 1 and M defaults to 1.  The state is
// resized to M mixtures of vector size vsize before parsing.
//
// Each mixture may be introduced by "<mixture> i w" (i must equal its
// 1-based index, w is stored in s.m_weight).  For a single-mixture state
// that header is optional and the weight defaults to 1.0.  This is the form
// operator<<(HMM_State) writes, so saved models can be read back.  With
// more than one mixture the header is mandatory.
//
// Returns 0 on success and -1 on any error, including errors reported by
// parse_mixture(), which were previously ignored.
int parse_state(EST_TokenStream &ts, HMM_State &s, int state_num, int vsize)
{
    int num_mix;

    if (parse_key(ts, "state") == -1)
        return -1;

    if (parse_int(ts, "state", state_num + 1) == -1)
        return -1;

    if (parse_optional_key(ts, "nummixes") == -1)
        num_mix = 1;
    else
        num_mix = atoi(ts.get().lstring());

    if (num_mix < 1)
    {
        cerr << "Invalid number of mixtures in state " << state_num + 1
             << ": " << num_mix << endl;
        return -1;
    }

    s.init(num_mix, vsize);

    for (int i = 0; i < num_mix; ++i)
    {
        if (parse_optional_key(ts, "mixture") == 0)
        {
            if (parse_int(ts, "mixture", i + 1) == -1)
                return -1;
            s.m_weight(i) = atof(ts.get().string());
        }
        else if (num_mix == 1)
            // Lone mixture without a header: it carries the full weight
            // (previously this weight was left unset).
            s.m_weight(i) = 1.0;
        else
        {
            // Multi-mixture states must label every component; let
            // parse_key() report the unexpected token.
            parse_key(ts, "mixture");
            return -1;
        }

        if (parse_mixture(ts, s.mixture(i)) == -1)
            return -1;
    }
    return 0;
}

// Read the global option block that precedes the first "~h" macro.
// Recognised options (compared in lower case):
//   <streaminfo> S w1..wS  -> n_streams = S (the per-stream widths are skipped)
//   <vecsize> N            -> v_size = N
//   <diagc>/<fullc>/<xformc>                 -> c_kind
//   <nulld>/<poissond>/<gammad>/<gend>       -> d_kind
//   anything accepted by legit_sample_kind() -> s_kind
// Stops with "~h" still in the stream (it is only peeked).  Returns 0 on
// success, -1 on an unknown option or if the input ends before "~h" is seen.
// Without the end-of-input check the loop never terminated, because the
// empty token returned at end of file is accepted as a sample kind.
int parse_global(EST_TokenStream &ts, int &n_streams, int &v_size,
                 EST_String &c_kind, EST_String &d_kind, EST_String &s_kind)
{
    EST_String name, next_token;

    name = ts.peek();
    next_token = downcase(name);

    while (next_token != "~h")
    {
        name = ts.get().lstring();

        // An empty token at end of input means "~h" never appeared.
        if ((name == "") && ts.eof())
        {
            cerr << "Unexpected end of file while reading global HMM options\n";
            return -1;
        }

        if (name == "<streaminfo>")
        {
            n_streams = atoi(ts.get().string()); // number of streams
            for (int i = 0; i < n_streams; ++i) // read one num for each stream
                ts.get().string();
        }
        else if (name == "<vecsize>")
            v_size = atoi(ts.get().string());

        else if ((name == "<diagc>") || (name == "<fullc>")
                 || (name ==  "<xformc>"))
            c_kind = name;

        else if ((name == "<nulld>") || (name == "<poissond>")
                 || (name == "<gammad>") || (name == "<gend>"))
            d_kind = name;

        else if (legit_sample_kind(name))
            s_kind = name;

        else
        {
            cerr << "Unknown Option definition: " << name << endl;
            return -1;
        }
        name = ts.peek();
        next_token = downcase(name);
    }

    return 0;
}

EST_read_status HMM::load(EST_String filename)
{
    EST_String c_token, next_token;
    EST_TokenStream ts;
    EST_Token t;
    int n_states, n_streams, v_size;
    n_states = n_streams = v_size = 0;
    
    if (((filename == "-") ? ts.open(cin) : ts.open(filename)) != 0)
    {
	cerr << "Can't open HMM input file " << filename << endl;
	return misc_read_error;
    }

    // Parse Global header
    if (parse_global(ts, n_streams, v_size, covkind, durkind, sampkind) == -1)
	return misc_read_error;

    if (ts.get().string() == "~h")
	name = ts.get().string();
    else
    {
	cerr << "Expected HMM macro name start (~h name)\n";
	return misc_read_error;
    }

    if (parse_key(ts, "beginhmm") == -1)
	return misc_read_error;

    // Parse Macro capability
    if (parse_optional_key(ts, "use") != -1)
    {
	cerr << "Macro reading Capability not implemented yet\n";
	return misc_read_error;
    }

    // Parse Options
    c_token = ts.peek();
    next_token = downcase(c_token);
    while (next_token != "<state>")
    {
	c_token = ts.get().lstring();
//	cout << "Option: " << c_token << endl;
	if (c_token == "<numstates>")
	    n_states = atoi(ts.get().string());
	else if (c_token == "<streaminfo>")
	    n_streams = atoi(ts.get().string());
	else if (c_token == "<vecsize>")
	    v_size = atoi(ts.get().string());
	
	else if ((c_token == "<diagc>") || (c_token == "<fullc>")
		 || (c_token ==  "<xformc>"))
	    covkind = c_token;
	
	else if ((c_token == "<nulld>") || (c_token == "<poissond>")
		 || (c_token == "<gammad>") || (c_token == "<gend>"))
	    durkind = c_token;
	
	else if (legit_sample_kind(c_token))
	    sampkind = c_token;
	
	else
	{
	    cerr << "Unknown Option definition: " << c_token << endl;
	    return misc_read_error;
	}
	c_token = ts.peek();
	next_token = downcase(c_token);
    }
    
    init(n_states, v_size, n_streams);
    
//    cout << "N states = " << n_states << endl;

    for (int i = 1; i < n_states - 1; ++i)
	parse_state(ts, state(i), i, v_size);

    if (parse_transition_matrix(ts, trans, n_states) == -1)
	return misc_read_error;
    
    ts.close();
    return format_ok;

}

EST_read_status HMM::load_portion(EST_TokenStream &ts, int v_size, int n_streams)
{
    EST_String c_token, next_token;
    EST_Token t;
    int n_states=1;

//    cout << "V states = " << v_size << endl;

    if (parse_key(ts, "beginhmm") == -1)
	return misc_read_error;

    // Parse Macro capability
    if (parse_optional_key(ts, "use") != -1)
    {
	cerr << "Macro reading Capability not implemented yet\n";
	return misc_read_error;
    }

    // Parse Options
    c_token = ts.peek();
    next_token = downcase(c_token);
    while (next_token != "<state>")
    {
	c_token = ts.get().lstring();
//	cout << "model option: " << c_token << endl;
	if (c_token == "<numstates>")
	    n_states = atoi(ts.get().string());
	else if (c_token == "<streaminfo>")
	    n_streams = atoi(ts.get().string());
	else if (c_token == "<vecsize>")
	    v_size = atoi(ts.get().string());
	
	else if ((c_token == "<diagc>") || (c_token == "<fullc>")
		 || (c_token ==  "<xformc>"))
	    covkind = c_token;
	
	else if ((c_token == "<nulld>") || (c_token == "<poissond>")
		 || (c_token == "<gammad>") || (c_token == "<gend>"))
	    durkind = c_token;
	
	else if (legit_sample_kind(c_token))
	    sampkind = c_token;
	
	else
	{
	    cerr << "Unknown Option definition: " << c_token << endl;
	    return misc_read_error;
	    
	}
	c_token = ts.peek();
	next_token = downcase(c_token);
    }
    
    init(n_states, v_size, n_streams);

    for (int i = 1; i < n_states - 1; ++i)
	parse_state(ts, state(i), i, v_size);

    if (parse_transition_matrix(ts, trans, n_states) == -1)
	return misc_read_error;

    if (parse_key(ts, "endhmm") == -1)
	return misc_read_error;
    return format_ok;
}

// Write this model with operator<<(ostream&, const HMM&) to filename, or
// to standard output when filename is "-".
//
// The output file is now a local stream, so it is flushed and closed when
// save() returns.  The old heap-allocated ofstream was never deleted, which
// leaked it and could leave buffered output unwritten.  An open failure is
// now detected from the stream state; new never returns 0, so the old null
// check could not catch it.
EST_write_status HMM::save(EST_String filename)
{
    if (filename == "-")
    {
        cout << *this;
        if (!cout)
            return misc_write_error;
        return write_ok;
    }

    ofstream outf(filename);
    if (!outf)
    {
        cerr << "Can't open HMM output file " << filename << endl;
        return misc_write_error;
    }

    outf << *this;
    outf.close();
    if (outf.fail())
        return misc_write_error;

    return write_ok;
}


// Write one mixture component in the form parse_mixture() reads:
//   <mean> N
//   m1 m2 ... mN
//   <variance> N
//   v1 v2 ... vN
// Values are space-separated on a single line.  With vecsize <= 0 an empty
// value line is written; the old loop read element 0 of an empty vector.
// gconst is deliberately not written (see note below).
ostream& operator << (ostream& s, const HMM_Mixture &mix)
{
    int i;

    s << "<mean> " << mix.vecsize << endl;
    for (i = 0; i < mix.vecsize; ++i)
        s << ((i > 0) ? " " : "") << mix.mean(i);
    s << endl;

    s << "<variance> " << mix.vecsize << endl;
    for (i = 0; i < mix.vecsize; ++i)
        s << ((i > 0) ? " " : "") << mix.var(i);
    s << endl;

    // only put in when we know how to calculate this
    //    s << "<gconst> " << endl;
    //    s << mix.gconst << endl;

    return s;
}


// Write one state's body: "<nummixes> M", then for every mixture k
// "<mixture> k+1 weight" followed by the mixture itself.
// Assumes st.m_weight has at least as many entries as st.mixture.
ostream& operator << (ostream& s, const HMM_State &st)
{
    const int n_mix = st.mixture.n();

    s << "<nummixes> " << n_mix << endl;

    int k = 0;
    while (k < n_mix)
    {
        s << "<mixture> " << k + 1 << " " << st.m_weight(k) << endl
          << st.mixture(k);
        ++k;
    }

    return s;
}

// Write the model body between <beginhmm> and <endhmm>: the state count,
// the inner states 1 .. n-2 (numbered from 2 in the output, matching what
// the loader reads), and the transition matrix.  No global header or
// "~h name" is written.
ostream& operator << (ostream& s, const HMM &model)
{
    const int n_states = model.state.n();

    s << "<beginhmm>" << endl
      << "<numstates> " << n_states << endl;

    // First and last states are skipped (presumably HTK-style
    // non-emitting entry/exit states -- confirm).
    for (int k = 1; k + 1 < n_states; ++k)
    {
        s << "<state> " << k + 1 << endl;
        s << model.state(k);
    }

    s << "<transp> " << n_states << endl;
    s << model.trans;
    s << "<endhmm>" << endl;

    return s;
}



// Inequality stub for mixtures: arguments are ignored and any two mixtures
// compare as "not different" (returns 0).  Presumably only present so
// container templates that need operator!= can be instantiated -- confirm.
int operator !=(HMM_Mixture &a, HMM_Mixture &b)
{
    (void) b;
    (void) a;
    const int never_different = 0;
    return never_different;
}

// Inequality stub for states: arguments are ignored and any two states
// compare as "not different" (returns 0).  Presumably only present so
// container templates that need operator!= can be instantiated -- confirm.
int operator !=(HMM_State &a, HMM_State &b)
{
    (void) b;
    (void) a;
    const int never_different = 0;
    return never_different;
}
 
