/* -*-c++-*- OpenSceneGraph - Copyright (C) 1998-2006 Robert Osfield 
 *
 * This library is open source and may be redistributed and/or modified under  
 * the terms of the OpenSceneGraph Public License (OSGPL) version 0.0 or 
 * (at your option) any later version.  The full license is in LICENSE file
 * included with this distribution, and on the openscenegraph.org website.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 * OpenSceneGraph Public License for more details.
*/

#ifndef OSGUTIL_STATISTICS
#define OSGUTIL_STATISTICS 1

#include <osg/PrimitiveSet>
#include <osg/Drawable>
#include <osg/NodeVisitor>
#include <osg/Geode>
#include <osg/LOD>
#include <osg/Switch>
#include <osg/Geometry>
#include <osg/Transform>

#include <map>
#include <set>
#include <ostream>

namespace osgUtil {

/**
 * Statistics base class. Used to extract primitive information from
 * the renderBin(s).  Add a case of getStats(osgUtil::Statistics *stat)
 * for any new drawable (or drawable derived class) that you generate
 * (e.g. see Geometry.cpp).  There are 20 types of drawable counted, although
 * only 14 cases can actually occur.  These represent sets of GL_POINTS, GL_LINES,
 * GL_LINE_STRIP, GL_LINE_LOOP, GL_TRIANGLES, GL_TRIANGLE_FAN, GL_TRIANGLE_STRIP,
 * GL_QUADS, GL_QUAD_STRIP, etc.
 * The number of primitives rendered is inferred from each set's vertex count:
 * each triangle set = nverts/3 triangles
 * each quad set = nverts/4 quads (i.e. nverts/2 triangles)
 * each triangle fan or strip = (length-2) triangles, and so on.
 */

class Statistics : public osg::PrimitiveFunctor
{
    public:

        typedef std::pair<unsigned int,unsigned int>    PrimitivePair;
        typedef std::map<GLenum,PrimitivePair>          PrimitiveValueMap;
        typedef std::map<GLenum, unsigned int>          PrimitiveCountMap;


        Statistics()
        {
            reset();
        };

        enum StatsType
        {
            STAT_NONE, // default
            STAT_FRAMERATE, 
            STAT_GRAPHS,
            STAT_PRIMS, 
            STAT_PRIMSPERVIEW, 
            STAT_PRIMSPERBIN,
            STAT_DC,
            STAT_RESTART // hint to restart the stats
        };
        
        void reset()
        {
            numDrawables=0;
            nummat=0;
            depth=0;
            stattype=STAT_NONE;
            nlights=0;
            nbins=0;
            nimpostor=0;
            
            _vertexCount=0;
            _primitiveCount.clear();            
            
            _currentPrimitiveFunctorMode=0;

            _primitives_count.clear();
            _total_primitives_count=0;
            _number_of_vertexes=0;
        }

        void setType(StatsType t) {stattype=t;}
        
        virtual void setVertexArray(unsigned int count,const osg::Vec3*) { _vertexCount += count; }
        virtual void setVertexArray(unsigned int count,const osg::Vec2*) { _vertexCount += count; }
        virtual void setVertexArray(unsigned int count,const osg::Vec4*) { _vertexCount += count; }

        virtual void drawArrays(GLenum mode,GLint,GLsizei count) 
        { 
            PrimitivePair& prim = _primitiveCount[mode]; 
            ++prim.first; 
            prim.second+=count; 
            _primitives_count[mode] += _calculate_primitives_number_by_mode(mode, count);
        } 
        virtual void drawElements(GLenum mode,GLsizei count,const GLubyte*) 
        { 
            PrimitivePair& prim = _primitiveCount[mode]; 
            ++prim.first; 
            prim.second+=count; 
            _primitives_count[mode] += _calculate_primitives_number_by_mode(mode, count);
        } 
        virtual void drawElements(GLenum mode,GLsizei count,const GLushort*)
        { 
            PrimitivePair& prim = _primitiveCount[mode]; 
            ++prim.first; 
            prim.second+=count; 
            _primitives_count[mode] += _calculate_primitives_number_by_mode(mode, count);
        } 
        virtual void drawElements(GLenum mode,GLsizei count,const GLuint*)
        { 
            PrimitivePair& prim = _primitiveCount[mode]; 
            ++prim.first; 
            prim.second+=count; 
            _primitives_count[mode] += _calculate_primitives_number_by_mode(mode, count);
        } 

        virtual void begin(GLenum mode) 
        { 
            _currentPrimitiveFunctorMode=mode; 
            PrimitivePair& prim = _primitiveCount[mode]; 
            ++prim.first; 
            _number_of_vertexes = 0;
        }

        inline void vertex() 
        { 
            PrimitivePair& prim = _primitiveCount[_currentPrimitiveFunctorMode]; 
            ++prim.second; 
           _number_of_vertexes++;
        }
        virtual void vertex(float,float,float) { vertex(); }
        virtual void vertex(const osg::Vec3&)  { vertex(); }
        virtual void vertex(const osg::Vec2&)  { vertex(); }
        virtual void vertex(const osg::Vec4&)  { vertex(); }
        virtual void vertex(float,float)   { vertex(); }
        virtual void vertex(float,float,float,float)  { vertex(); }

        virtual void end() 
        {
          _primitives_count[_currentPrimitiveFunctorMode] += 
            _calculate_primitives_number_by_mode(_currentPrimitiveFunctorMode, _number_of_vertexes);
            
           _vertexCount += _number_of_vertexes;
        }
        
        void addDrawable() { numDrawables++;}
        void addMatrix() { nummat++;}
        void addLight(int np) { nlights+=np;}
        void addImpostor(int np) { nimpostor+= np; }
        inline int getBins() { return nbins;}
        void setDepth(int d) { depth=d; }
        void addBins(int np) { nbins+= np; }

        void setBinNo(int n) { _binNo=n;}
        
        void add(const Statistics& stats)
        {
            numDrawables += stats.numDrawables;
            nummat += stats.nummat;
            depth += stats.depth;
            nlights += stats.nlights;
            nbins += stats.nbins;
            nimpostor += stats.nimpostor;
            
            _vertexCount += stats._vertexCount;
            // _primitiveCount += stats._primitiveCount;   
            for(PrimitiveValueMap::const_iterator pitr = stats._primitiveCount.begin();
                pitr != stats._primitiveCount.end();
                ++pitr)
            {
                _primitiveCount[pitr->first].first += pitr->second.first;
                _primitiveCount[pitr->first].second += pitr->second.second;
            }
            
            _currentPrimitiveFunctorMode += stats._currentPrimitiveFunctorMode;

            for(PrimitiveCountMap::const_iterator citr = stats._primitives_count.begin();
                citr != stats._primitives_count.end();
                ++citr)
            {
                _primitives_count[citr->first] += citr->second;
            }

            _total_primitives_count += stats._total_primitives_count;
            _number_of_vertexes += stats._number_of_vertexes;
        }
                
    public:
                
        PrimitiveCountMap::iterator GetPrimitivesBegin() { return _primitives_count.begin(); }
        PrimitiveCountMap::iterator GetPrimitivesEnd() { return _primitives_count.end(); }

        int numDrawables, nummat, nbins;
        int nlights;
        int depth; // depth into bins - eg 1.1,1.2,1.3 etc
        int _binNo;
        StatsType stattype;
        int nimpostor; // number of impostors rendered
        
        unsigned int        _vertexCount;
        PrimitiveValueMap    _primitiveCount;
        GLenum              _currentPrimitiveFunctorMode;

    private:
        PrimitiveCountMap                     _primitives_count;

        unsigned int                         _total_primitives_count;
        unsigned int                         _number_of_vertexes;

        inline unsigned int _calculate_primitives_number_by_mode(GLenum, GLsizei);
};

/** Number of primitives that `count` vertices form when drawn with GL mode `mode`.
  *
  * Non-positive counts and counts too small to form a single strip/fan primitive
  * yield 0.  Previously `count - 1`, `count - 2` and `count / 2 - 1` could go
  * negative and wrap around to huge unsigned values.  Unknown modes yield 0. */
inline unsigned int Statistics::_calculate_primitives_number_by_mode(GLenum mode, GLsizei count)
{
    if (count <= 0) return 0;

    const unsigned int n = static_cast<unsigned int>(count);
    switch (mode)
    {
        case GL_POINTS:
        case GL_LINE_LOOP:
        case GL_POLYGON:        return n; // NOTE(review): a polygon counts one per vertex here — confirm intent
        case GL_LINES:          return n / 2;
        case GL_LINE_STRIP:     return n >= 2 ? n - 1 : 0;
        case GL_TRIANGLES:      return n / 3;
        case GL_TRIANGLE_STRIP:
        case GL_TRIANGLE_FAN:   return n >= 3 ? n - 2 : 0;
        case GL_QUADS:          return n / 4;
        case GL_QUAD_STRIP:     return n >= 4 ? n / 2 - 1 : 0;
        default:                return 0;
    }
}

/** Scene-graph visitor that tallies node, drawable and state-set usage.
  *
  * Every apply() overload bumps an "instanced" counter each time an object is
  * reached, and records the object's address in a set.  After a traversal, the
  * set sizes give the number of distinct objects, and the counters give the
  * number of visits.  Vertex and primitive figures come from two
  * osgUtil::Statistics functors:
  *   - _instancedStats is fed once per visit, in apply(osg::Drawable&).
  *   - _uniqueStats is fed once per distinct drawable, in totalUpStats().
  * Call totalUpStats() before print() so the unique column is filled in. */
class StatsVisitor : public osg::NodeVisitor
{
public:

    typedef std::set<osg::Node*> NodeSet;
    typedef std::set<osg::Drawable*> DrawableSet;
    typedef std::set<osg::StateSet*> StateSetSet;

    /** Create a visitor using TRAVERSE_ALL_CHILDREN, with every counter zeroed. */
    StatsVisitor():
        osg::NodeVisitor(osg::NodeVisitor::TRAVERSE_ALL_CHILDREN)
    {
        reset();
    }
        
    /** Forget everything gathered so far: counters, object sets and both Statistics. */
    void reset()
    {
        _numInstancedGroup = _numInstancedSwitch = _numInstancedLOD = 0;
        _numInstancedTransform = _numInstancedGeode = 0;
        _numInstancedDrawable = _numInstancedGeometry = _numInstancedStateSet = 0;

        _groupSet.clear();
        _transformSet.clear();
        _lodSet.clear();
        _switchSet.clear();
        _geodeSet.clear();
        _drawableSet.clear();
        _geometrySet.clear();
        _statesetSet.clear();

        _uniqueStats.reset();
        _instancedStats.reset();
    }
    
    /** Generic node: only its StateSet (if any) is counted. */
    void apply(osg::Node& node)
    {
        osg::StateSet* stateset = node.getStateSet();
        if (stateset != 0)
        {
            _statesetSet.insert(stateset);
            ++_numInstancedStateSet;
        }
        traverse(node);
    }

    void apply(osg::Group& node)
    {
        osg::StateSet* stateset = node.getStateSet();
        if (stateset != 0)
        {
            _statesetSet.insert(stateset);
            ++_numInstancedStateSet;
        }

        _groupSet.insert(&node);
        ++_numInstancedGroup;
        traverse(node);
    }

    void apply(osg::Transform& node)
    {
        osg::StateSet* stateset = node.getStateSet();
        if (stateset != 0)
        {
            _statesetSet.insert(stateset);
            ++_numInstancedStateSet;
        }

        _transformSet.insert(&node);
        ++_numInstancedTransform;
        traverse(node);
    }

    void apply(osg::LOD& node)
    {
        osg::StateSet* stateset = node.getStateSet();
        if (stateset != 0)
        {
            _statesetSet.insert(stateset);
            ++_numInstancedStateSet;
        }

        _lodSet.insert(&node);
        ++_numInstancedLOD;
        traverse(node);
    }

    void apply(osg::Switch& node)
    {
        osg::StateSet* stateset = node.getStateSet();
        if (stateset != 0)
        {
            _statesetSet.insert(stateset);
            ++_numInstancedStateSet;
        }

        _switchSet.insert(&node);
        ++_numInstancedSwitch;
        traverse(node);
    }

    /** Geode: count it, hand each of its drawables to apply(osg::Drawable&), then traverse.
      * NOTE(review): this assumes traverse() does not also visit the drawables
      * (true while osg::Drawable is not an osg::Node) — confirm for the OSG version in use. */
    void apply(osg::Geode& node)
    {
        osg::StateSet* stateset = node.getStateSet();
        if (stateset != 0)
        {
            _statesetSet.insert(stateset);
            ++_numInstancedStateSet;
        }

        _geodeSet.insert(&node);
        ++_numInstancedGeode;

        for (unsigned int index = 0; index < node.getNumDrawables(); ++index)
        {
            osg::Drawable& child = *node.getDrawable(index);
            apply(child);
        }

        traverse(node);
    }

    /** Drawable: count it (and its StateSet) and feed its primitives to _instancedStats.
      * Drawables that are osg::Geometry are additionally tracked as geometry. */
    void apply(osg::Drawable& drawable)
    {
        osg::StateSet* stateset = drawable.getStateSet();
        if (stateset != 0)
        {
            _statesetSet.insert(stateset);
            ++_numInstancedStateSet;
        }

        ++_numInstancedDrawable;
        drawable.accept(_instancedStats);
        _drawableSet.insert(&drawable);

        if (osg::Geometry* geometry = dynamic_cast<osg::Geometry*>(&drawable))
        {
            ++_numInstancedGeometry;
            _geometrySet.insert(geometry);
        }
    }

    /** Rebuild _uniqueStats by feeding each distinct drawable to it exactly once. */
    void totalUpStats()
    {
        _uniqueStats.reset();

        DrawableSet::iterator current = _drawableSet.begin();
        const DrawableSet::iterator last = _drawableSet.end();
        for (; current != last; ++current)
        {
            (*current)->accept(_uniqueStats);
        }
    }
    
    /** Write a tab-separated table of unique vs. instanced counts to `out`. */
    void print(std::ostream& out)
    {
        typedef osgUtil::Statistics::PrimitiveCountMap::iterator CountIterator;

        unsigned int uniquePrimitives = 0;
        for (CountIterator entry = _uniqueStats.GetPrimitivesBegin();
             entry != _uniqueStats.GetPrimitivesEnd();
             ++entry)
        {
            uniquePrimitives += entry->second;
        }

        unsigned int instancedPrimitives = 0;
        for (CountIterator entry = _instancedStats.GetPrimitivesBegin();
             entry != _instancedStats.GetPrimitivesEnd();
             ++entry)
        {
            instancedPrimitives += entry->second;
        }

        out << "Object Type\t#Unique\t#Instanced" << std::endl;
        out << "StateSet      \t" << _statesetSet.size()  << "\t" << _numInstancedStateSet  << std::endl;
        out << "Group      \t"    << _groupSet.size()     << "\t" << _numInstancedGroup     << std::endl;
        out << "Transform  \t"    << _transformSet.size() << "\t" << _numInstancedTransform << std::endl;
        out << "LOD        \t"    << _lodSet.size()       << "\t" << _numInstancedLOD       << std::endl;
        out << "Switch     \t"    << _switchSet.size()    << "\t" << _numInstancedSwitch    << std::endl;
        out << "Geode      \t"    << _geodeSet.size()     << "\t" << _numInstancedGeode     << std::endl;
        out << "Drawable   \t"    << _drawableSet.size()  << "\t" << _numInstancedDrawable  << std::endl;
        out << "Geometry   \t"    << _geometrySet.size()  << "\t" << _numInstancedGeometry  << std::endl;
        out << "Vertices   \t"    << _uniqueStats._vertexCount << "\t" << _instancedStats._vertexCount << std::endl;
        out << "Primitives \t"    << uniquePrimitives     << "\t" << instancedPrimitives    << std::endl;
    }
    
    // Visit counts (an object reached N times counts N).
    unsigned int _numInstancedGroup;
    unsigned int _numInstancedSwitch;
    unsigned int _numInstancedLOD;
    unsigned int _numInstancedTransform;
    unsigned int _numInstancedGeode;
    unsigned int _numInstancedDrawable;
    unsigned int _numInstancedGeometry;
    unsigned int _numInstancedStateSet;

    // Distinct objects reached (their sizes are the "unique" counts).
    NodeSet _groupSet;
    NodeSet _transformSet;
    NodeSet _lodSet;
    NodeSet _switchSet;
    NodeSet _geodeSet;
    DrawableSet _drawableSet;
    DrawableSet _geometrySet;
    StateSetSet _statesetSet;

    osgUtil::Statistics _uniqueStats;    // filled by totalUpStats()
    osgUtil::Statistics _instancedStats; // filled during traversal
};

}

#endif
