/*
 *  Copyright 2001-2005 Internet2
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* SAMLConfig.cpp - SAML runtime configuration

   Scott Cantor
   2/20/02

   $History:$
*/

#include "internal.h"
#include "version.h"

#ifdef HAVE_DLFCN_H
# include <dlfcn.h>
#endif

#if defined(HAVE_LOG4SHIB)
# include <log4shib/PropertyConfigurator.hh>
#elif defined(HAVE_LOG4CPP)
# include <log4cpp/PropertyConfigurator.hh>
#else
# error "Supported logging library not available."
#endif

#include <curl/curl.h>
#include <xercesc/util/PlatformUtils.hpp>
#include <xsec/framework/XSECProvider.hpp>

#include <fstream>

using namespace saml::logging;
using namespace saml;
using namespace std;

// Expand SAML_EXCEPTION_FACTORY for each exception type this library defines.
// The matching REGISTER_EXCEPTION_FACTORY calls live in SAMLInternalConfig::init();
// keep the two lists in sync when adding or removing an exception type.

// General / security-related exceptions.
SAML_EXCEPTION_FACTORY(MalformedException);
SAML_EXCEPTION_FACTORY(UnsupportedExtensionException);
SAML_EXCEPTION_FACTORY(InvalidCryptoException);
SAML_EXCEPTION_FACTORY(TrustException);

// Binding / transport exceptions.
SAML_EXCEPTION_FACTORY(BindingException);
SAML_EXCEPTION_FACTORY(SOAPException);
SAML_EXCEPTION_FACTORY(HTTPException);
SAML_EXCEPTION_FACTORY(ContentTypeException);
SAML_EXCEPTION_FACTORY(UnknownAssertionException);

// Profile-processing exceptions.
SAML_EXCEPTION_FACTORY(ProfileException);
SAML_EXCEPTION_FACTORY(FatalProfileException);
SAML_EXCEPTION_FACTORY(RetryableProfileException);
SAML_EXCEPTION_FACTORY(UnsupportedProfileException);
SAML_EXCEPTION_FACTORY(ExpiredAssertionException);
SAML_EXCEPTION_FACTORY(InvalidAssertionException);
SAML_EXCEPTION_FACTORY(ReplayedAssertionException);

// C-linkage factories for the built-in SAML query, statement and condition types.
// SAMLInternalConfig::init() registers each one against the schema QNames of its
// element and type. Every factory builds the object from the supplied DOM element
// and returns a freshly heap-allocated instance; no reference is kept here.

extern "C" SAMLQuery* SAMLAttributeQueryFactory(DOMElement* e)
{
    SAMLQuery* built=new SAMLAttributeQuery(e);
    return built;
}

extern "C" SAMLStatement* SAMLAttributeStatementFactory(DOMElement* e)
{
    SAMLStatement* built=new SAMLAttributeStatement(e);
    return built;
}

extern "C" SAMLQuery* SAMLAuthenticationQueryFactory(DOMElement* e)
{
    SAMLQuery* built=new SAMLAuthenticationQuery(e);
    return built;
}

extern "C" SAMLStatement* SAMLAuthenticationStatementFactory(DOMElement* e)
{
    SAMLStatement* built=new SAMLAuthenticationStatement(e);
    return built;
}

extern "C" SAMLQuery* SAMLAuthorizationDecisionQueryFactory(DOMElement *e)
{
    SAMLQuery* built=new SAMLAuthorizationDecisionQuery(e);
    return built;
}

extern "C" SAMLStatement* SAMLAuthorizationDecisionStatementFactory(DOMElement *e)
{
    SAMLStatement* built=new SAMLAuthorizationDecisionStatement(e);
    return built;
}

// Both AudienceRestrictionCondition QNames map to this single factory.
extern "C" SAMLCondition* SAMLAudienceConditionFactory(DOMElement* e)
{
    SAMLCondition* built=new SAMLAudienceRestrictionCondition(e);
    return built;
}

/**
 * Registers (or replaces) the DOM-based plugin factory for a plugin type name.
 * Calls with a NULL type or NULL factory are ignored.
 */
void PlugManager::regFactory(const char* type, Factory* factory)
{
    if (!type || !factory)
        return;
    m_map[type]=factory;
}

/**
 * Registers (or replaces) the qualifier-aware plugin factory for a plugin type name.
 * Calls with a NULL type or NULL factory are ignored.
 */
void PlugManager::regFactory(const char* type, XMLChFactory* factory)
{
    if (!type || !factory)
        return;
    m_XMLCh_map[type]=factory;
}

/**
 * Builds a plugin of the given type using the factory registered via
 * regFactory(const char*, Factory*).
 *
 * @param type  registered plugin type name
 * @param e     DOM configuration element passed through to the factory
 * @return      the object produced by the factory
 * @throws UnsupportedExtensionException if type is NULL or no factory is registered
 */
IPlugIn* PlugManager::newPlugin(const char* type, const DOMElement* e)
{
    // A NULL type would be converted to std::string below (undefined behavior);
    // report it the same way as an unknown type.
    if (!type)
        throw UnsupportedExtensionException(string("unable to build plugin of unspecified type"));
    FactoryMap::const_iterator i=m_map.find(type);
    if (i==m_map.end())
        throw UnsupportedExtensionException(string("unable to build plugin of type '") + type + "'");
    return i->second(e);
}

/**
 * Builds a plugin of the given type using the qualifier-aware factory registered
 * via regFactory(const char*, XMLChFactory*).
 *
 * @param type       registered plugin type name
 * @param qualifier  extra qualifier passed through to the factory (may be NULL)
 * @param e          DOM configuration element passed through to the factory
 * @return           the object produced by the factory
 * @throws UnsupportedExtensionException if type is NULL or no factory is registered
 */
IPlugIn* PlugManager::newPlugin(const char* type, const XMLCh* qualifier, const DOMElement* e)
{
    // A NULL type would be converted to std::string below (undefined behavior).
    if (!type)
        throw UnsupportedExtensionException(string("unable to build plugin of unspecified type"));
    XMLChFactoryMap::const_iterator i=m_XMLCh_map.find(type);
    if (i==m_XMLCh_map.end())
        throw UnsupportedExtensionException(string("unable to build plugin of type '") + type + "'");
    return i->second(qualifier, e);
}

/**
 * Removes any factory registered under the given type name. Both factory maps
 * are purged, because a name may have been registered through either overload
 * of regFactory(). A NULL type is ignored.
 */
void PlugManager::unregFactory(const char* type)
{
    if (!type)
        return;
    m_map.erase(type);
    m_XMLCh_map.erase(type);
}

// Built-in plugin factories that are defined elsewhere in the library (not in
// this file); SAMLInternalConfig::init() registers them with the plugin manager.
PlugManager::Factory MemoryReplayCacheFactory;
PlugManager::Factory BrowserProfileFactory;
PlugManager::XMLChFactory SOAPBindingFactory;

// Artifact factories for type codes 0x0001 and 0x0002, also defined elsewhere;
// init() registers them with SAMLArtifact::regFactory().
extern "C" {
    SAMLArtifactFactory SAMLArtifactType0001Factory;
    SAMLArtifactFactory SAMLArtifactType0002Factory;
}

// The library's single configuration object. It has internal linkage and is
// handed out by SAMLConfig::getConfig().
namespace {
   SAMLInternalConfig g_config;
}

/**
 * Returns the library-wide configuration object (the file-local g_config,
 * exposed through its SAMLConfig base).
 */
SAMLConfig& SAMLConfig::getConfig()
{
    SAMLConfig& singleton=g_config;
    return singleton;
}

// Built-in XMLSig algorithm configuration, parsed by SAMLInternalConfig::init()
// when no external alg_config file is given. Each element maps an algorithm URI
// to the numeric value of an XML-Security signatureMethod ("Signature") or
// hashMethod ("Digest") enumerator; init() casts "num" directly to those enums.
// The extended algorithms are only listed when built against XML-Security >= 1.2.
// URIs follow RFC 4051. The SHA-224/SHA-384 digest and RSA-MD5 signature entries
// previously used non-existent "2000/09/xmldsig-more" / "2000/09/xmldsig#rsa-md5"
// URIs, so they could never match.
static const char defaultAlgConfig[] =
    "<AlgorithmConfig xmlns=\"http://www.opensaml.org\">"
        "<Signature uri=\"http://www.w3.org/2000/09/xmldsig#dsa-sha1\" num=\"1\"/>"
        "<Signature uri=\"http://www.w3.org/2000/09/xmldsig#hmac-sha1\" num=\"2\"/>"
        "<Signature uri=\"http://www.w3.org/2000/09/xmldsig#rsa-sha1\" num=\"3\"/>"
        "<Digest uri=\"http://www.w3.org/2000/09/xmldsig#sha1\" num=\"1\"/>"
        "<Digest uri=\"http://www.w3.org/2001/04/xmldsig-more#md5\" num=\"2\"/>"
#if (XSEC_VERSION_MAJOR > 1) || (((XSEC_VERSION_MAJOR == 1)) && (XSEC_VERSION_MEDIUM >= 2))
        "<Signature uri=\"http://www.w3.org/2001/04/xmldsig-more#hmac-sha224\" num=\"2\"/>"
        "<Signature uri=\"http://www.w3.org/2001/04/xmldsig-more#hmac-sha256\" num=\"2\"/>"
        "<Signature uri=\"http://www.w3.org/2001/04/xmldsig-more#hmac-sha384\" num=\"2\"/>"
        "<Signature uri=\"http://www.w3.org/2001/04/xmldsig-more#hmac-sha512\" num=\"2\"/>"
        "<Signature uri=\"http://www.w3.org/2001/04/xmldsig-more#rsa-sha224\" num=\"3\"/>"
        "<Signature uri=\"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256\" num=\"3\"/>"
        "<Signature uri=\"http://www.w3.org/2001/04/xmldsig-more#rsa-sha384\" num=\"3\"/>"
        "<Signature uri=\"http://www.w3.org/2001/04/xmldsig-more#rsa-sha512\" num=\"3\"/>"
        "<Signature uri=\"http://www.w3.org/2001/04/xmldsig-more#rsa-md5\" num=\"3\"/>"
        "<Digest uri=\"http://www.w3.org/2001/04/xmldsig-more#sha224\" num=\"3\"/>"
        "<Digest uri=\"http://www.w3.org/2001/04/xmlenc#sha256\" num=\"4\"/>"
        "<Digest uri=\"http://www.w3.org/2001/04/xmldsig-more#sha384\" num=\"5\"/>"
        "<Digest uri=\"http://www.w3.org/2001/04/xmlenc#sha512\" num=\"6\"/>"
#endif
    "</AlgorithmConfig>";

bool SAMLInternalConfig::init()
{
    try {
        if (!log_config.empty())
            PropertyConfigurator::configure(log_config);
        m_log=&(Category::getInstance(SAML_LOGCAT".SAMLInternalConfig"));
        m_log->debug("library initialization started");

        if (curl_global_init(CURL_GLOBAL_ALL))
        {
            m_log->fatal("init: failed to initialize libcurl, SSL, or Winsock");
            return false;
        }
        m_log->debug("libcurl %s initialization complete",LIBCURL_VERSION);

        XMLPlatformUtils::Initialize();
        m_log->debug("Xerces %s initialization complete",XERCES_FULLVERSIONDOT);

        soap_pool_init();

        XSECPlatformUtils::Initialise();
        m_xsec=new XSECProvider();
        m_log->debug("XML-Security %s initialization complete",XSEC_FULLVERSIONDOT);
        
        if (schema_dir[schema_dir.length()-1]!='/')
            schema_dir+='/';
        wide_schema_dir=XMLString::transcode(schema_dir.c_str());

        if (inclusive_namespace_prefixes.empty())
            inclusive_namespace_prefixes="#default saml samlp ds xsd xsi code kind rw typens";

        wide_inclusive_namespace_prefixes=XMLString::transcode(inclusive_namespace_prefixes.c_str());

        m_pool=new XML::ParserPool();
        m_pool->registerSchema(XML::SAML_NS,XML::SAML11_SCHEMA_ID);
        m_pool->registerSchema(XML::SAMLP_NS,XML::SAMLP11_SCHEMA_ID);
        m_pool->registerSchema(XML::SOAP11ENV_NS,XML::SOAP11ENV_SCHEMA_ID);
        m_pool->registerSchema(XML::XMLSIG_NS,XML::XMLSIG_SCHEMA_ID);
        m_pool->registerSchema(XML::XML_NS,XML::XML_SCHEMA_ID);
        
        m_compat_pool=new XML::ParserPool();
        m_compat_pool->registerSchema(XML::SAML_NS,XML::SAML_SCHEMA_ID);
        m_compat_pool->registerSchema(XML::SAMLP_NS,XML::SAMLP_SCHEMA_ID);
        m_compat_pool->registerSchema(XML::SOAP11ENV_NS,XML::SOAP11ENV_SCHEMA_ID);
        m_compat_pool->registerSchema(XML::XMLSIG_NS,XML::XMLSIG_SCHEMA_ID);
        m_compat_pool->registerSchema(XML::XML_NS,XML::XML_SCHEMA_ID);
        
        m_log->debug("SAML schema registration complete");

        m_lock=XMLPlatformUtils::makeMutex();

        // Register built-in SAML type factories.
        saml::QName q1(XML::SAMLP_NS,L(AttributeQueryType));
        saml::QName q2(XML::SAMLP_NS,L(AttributeQuery));
        SAMLQuery::regFactory(q1,&SAMLAttributeQueryFactory);
        SAMLQuery::regFactory(q2,&SAMLAttributeQueryFactory);

        saml::QName s1(XML::SAML_NS,L(AttributeStatementType));
        saml::QName s2(XML::SAML_NS,L(AttributeStatement));
        SAMLStatement::regFactory(s1,&SAMLAttributeStatementFactory);
        SAMLStatement::regFactory(s2,&SAMLAttributeStatementFactory);

        saml::QName q3(XML::SAMLP_NS,L(AuthenticationQueryType));
        saml::QName q4(XML::SAMLP_NS,L(AuthenticationQuery));
        SAMLQuery::regFactory(q3,&SAMLAuthenticationQueryFactory);
        SAMLQuery::regFactory(q4,&SAMLAuthenticationQueryFactory);

        saml::QName s3(XML::SAML_NS,L(AuthenticationStatementType));
        saml::QName s4(XML::SAML_NS,L(AuthenticationStatement));
        SAMLStatement::regFactory(s3,&SAMLAuthenticationStatementFactory);
        SAMLStatement::regFactory(s4,&SAMLAuthenticationStatementFactory);

        QName q5(XML::SAMLP_NS,L(AuthorizationDecisionQueryType));
        QName q6(XML::SAMLP_NS,L(AuthorizationDecisionQuery));
        SAMLQuery::regFactory(q5,&SAMLAuthorizationDecisionQueryFactory);
        SAMLQuery::regFactory(q6,&SAMLAuthorizationDecisionQueryFactory);

        QName s5(XML::SAML_NS,L(AuthorizationDecisionStatementType));
        QName s6(XML::SAML_NS,L(AuthorizationDecisionStatement));
        SAMLStatement::regFactory(s5,&SAMLAuthorizationDecisionStatementFactory);
        SAMLStatement::regFactory(s6,&SAMLAuthorizationDecisionStatementFactory);

        saml::QName c1(XML::SAML_NS,L(AudienceRestrictionConditionType));
        saml::QName c2(XML::SAML_NS,L(AudienceRestrictionCondition));
        SAMLCondition::regFactory(c1,&SAMLAudienceConditionFactory);
        SAMLCondition::regFactory(c2,&SAMLAudienceConditionFactory);

        REGISTER_EXCEPTION_FACTORY(MalformedException);
        REGISTER_EXCEPTION_FACTORY(UnsupportedExtensionException);
        REGISTER_EXCEPTION_FACTORY(InvalidCryptoException);
        REGISTER_EXCEPTION_FACTORY(TrustException);
        
        REGISTER_EXCEPTION_FACTORY(BindingException);
        REGISTER_EXCEPTION_FACTORY(SOAPException);
        REGISTER_EXCEPTION_FACTORY(HTTPException);
        REGISTER_EXCEPTION_FACTORY(ContentTypeException);
        REGISTER_EXCEPTION_FACTORY(UnknownAssertionException);
        
        REGISTER_EXCEPTION_FACTORY(ProfileException);
        REGISTER_EXCEPTION_FACTORY(FatalProfileException);
        REGISTER_EXCEPTION_FACTORY(RetryableProfileException);
        REGISTER_EXCEPTION_FACTORY(UnsupportedProfileException);
        REGISTER_EXCEPTION_FACTORY(ExpiredAssertionException);
        REGISTER_EXCEPTION_FACTORY(InvalidAssertionException);
        REGISTER_EXCEPTION_FACTORY(ReplayedAssertionException);

        m_plugMgr.regFactory(DEFAULT_REPLAYCACHE_PROVIDER,&MemoryReplayCacheFactory);
        m_plugMgr.regFactory(DEFAULT_BROWSERPROFILE_PROVIDER,&BrowserProfileFactory);

        setDefaultBindingProvider(SAMLBinding::SOAP,DEFAULT_SOAPBINDING_PROVIDER);
        m_plugMgr.regFactory(DEFAULT_SOAPBINDING_PROVIDER,&SOAPBindingFactory);

        string typecode;
        typecode+=(char)0x0;
        typecode+=(char)0x1;
        SAMLArtifact::regFactory(typecode,&SAMLArtifactType0001Factory);
        typecode[1]=(char)0x2;
        SAMLArtifact::regFactory(typecode,&SAMLArtifactType0002Factory);
        
        m_log->debug("SAML type factory registration complete");
        
        // Load XMLSig algorithm mappings.
        DOMDocument* algConfigDoc=NULL;
        DOMBuilder* parser=m_pool->get(false);
        if (alg_config.empty()) {
            m_log->debug("loading default XMLSig algorithm configuration");
            istringstream in(defaultAlgConfig);
            XML::StreamInputSource src(in);
            Wrapper4InputSource dsrc(&src,false);
            try {
                algConfigDoc=parser->parse(dsrc);
            }
            catch (...) {
                m_log->error("caught exception while parsing built-in XMLSig algorithm configuration");
                parser->release();
            }
        }
        else {
            m_log->debug("loading external XMLSig algorithm configuration from file (%s)", alg_config.c_str());
            ifstream in(alg_config.c_str());
            XML::StreamInputSource src(in);
            Wrapper4InputSource dsrc(&src,false);
            try {
                algConfigDoc=parser->parse(dsrc);
            }
            catch (...) {
                m_log->error("caught exception while parsing XMLSig algorithm configuration in file (%s)", alg_config.c_str());
                parser->release();
            }
        }
        if (algConfigDoc) {
            static const XMLCh uri[]={chLatin_u, chLatin_r, chLatin_i, chNull};
            static const XMLCh num[]={chLatin_n, chLatin_u, chLatin_m, chNull};
            DOMNodeList* nlist=algConfigDoc->getDocumentElement()->getElementsByTagNameNS(XML::OPENSAML_NS,L(Signature));
            unsigned int algs;
            for (algs=0; nlist && algs < nlist->getLength(); algs++) {
                auto_ptr_char u(static_cast<DOMElement*>(nlist->item(algs))->getAttributeNS(NULL,uri));
                auto_ptr_char n(static_cast<DOMElement*>(nlist->item(algs))->getAttributeNS(NULL,num));
                m_log->debug("loaded Signature mapping (%s) -> (%s)", u.get(), n.get());
                m_sigAlgFromURI.insert(pair<string,signatureMethod>(u.get(),static_cast<signatureMethod>(atoi(n.get()))));
            }
            nlist=algConfigDoc->getDocumentElement()->getElementsByTagNameNS(XML::OPENSAML_NS,L(Digest));
            for (algs=0; nlist && algs < nlist->getLength(); algs++) {
                auto_ptr_char u(static_cast<DOMElement*>(nlist->item(algs))->getAttributeNS(NULL,uri));
                auto_ptr_char n(static_cast<DOMElement*>(nlist->item(algs))->getAttributeNS(NULL,num));
                m_log->debug("loaded Digest mapping (%s) -> (%s)", u.get(), n.get());
                m_digestAlgFromURI.insert(pair<string,hashMethod>(u.get(),static_cast<hashMethod>(atoi(n.get()))));
            }
            algConfigDoc->release();
        }
    }
    catch(const ConfigureFailure& e) {
        cerr << "SAMLConfig::init() caught exception while initializing log4cpp: " << e.what() << endl;
        return false;
    }
    catch (const XMLException&) {
        m_log->fatal("caught exception while initializing Xerces");
        curl_global_cleanup();
        return false;
    }

    m_log->info("OpenSAML %s library initialization complete", OPENSAML_FULLVERSIONDOT);
    return true;
}

void SAMLInternalConfig::term()
{
    m_plugMgr.unregFactory("org.opensaml.provider.MemoryReplayCache");
    
    for (vector<void*>::reverse_iterator i=m_libhandles.rbegin(); i!=m_libhandles.rend(); i++)
    {
#if defined(WIN32)
        FARPROC fn=GetProcAddress(static_cast<HMODULE>(*i),"saml_extension_term");
        if (fn)
            fn();
        FreeLibrary(static_cast<HMODULE>(*i));
#elif defined(HAVE_DLFCN_H)
        void (*fn)()=(void (*)())dlsym(*i,"saml_extension_term");
        if (fn)
            fn();
        dlclose(*i);
#else
# error "Don't know about dynamic loading on this platform!"
#endif
    }
    m_libhandles.clear();
    
    delete m_xsec; m_xsec=NULL;
    XSECPlatformUtils::Terminate();
    XMLPlatformUtils::closeMutex(m_lock);
    delete m_pool; m_pool=NULL;
    delete m_compat_pool; m_compat_pool=NULL;
    if (wide_schema_dir) {
        XMLString::release(&wide_schema_dir);
        wide_schema_dir=NULL;
    }
    if (wide_inclusive_namespace_prefixes) {
        XMLString::release(&wide_inclusive_namespace_prefixes);
        wide_inclusive_namespace_prefixes=NULL;
    }
    soap_pool_term();
    XMLPlatformUtils::Terminate();
    curl_global_cleanup();

    m_log->info("OpenSAML %s library shutdown complete", OPENSAML_FULLVERSIONDOT);
    m_log=NULL;
}

/// Acquires the library-wide mutex created by init(); pair with saml_unlock().
void SAMLInternalConfig::saml_lock() const
{
    void* const mutex=m_lock;
    XMLPlatformUtils::lockMutex(mutex);
}

/// Releases the library-wide mutex previously acquired with saml_lock().
void SAMLInternalConfig::saml_unlock() const
{
    void* const mutex=m_lock;
    XMLPlatformUtils::unlockMutex(mutex);
}

/**
 * Loads an extension library and runs its initializer.
 *
 * Opens the shared library at path and calls its saml_extension_init(context)
 * entry point. On success the handle is recorded, so that term() can call
 * saml_extension_term and unload the library.
 *
 * @param path     path of the extension library (on Windows, '/' is converted to '\\')
 * @param context  opaque pointer passed to saml_extension_init
 * @throws SAMLException if the library cannot be loaded, lacks saml_extension_init,
 *         or saml_extension_init returns non-zero. The library is unloaded in
 *         those cases.
 */
void SAMLInternalConfig::saml_register_extension(const char* path, void* context) const
{
#ifdef _DEBUG
    saml::NDC ndc("saml_register_extension");
#endif
    m_log->info("loading extension: %s",path);

#if defined(WIN32)
    HMODULE handle=NULL;
    char* fixed=const_cast<char*>(path);
    if (strchr(fixed,'/'))
    {
        fixed=strdup(path);
        char* p=fixed;
        while (p=strchr(p,'/'))
            *p='\\';
    }

    UINT em=SetErrorMode(SEM_FAILCRITICALERRORS);
    try
    {
        handle=LoadLibraryEx(fixed,NULL,LOAD_WITH_ALTERED_SEARCH_PATH);
        if (!handle)
             handle=LoadLibraryEx(fixed,NULL,0);
        if (!handle)
            throw SAMLException(string("SAMLConfig::saml_register_extension() unable to load extension library: ") + fixed);
        FARPROC fn=GetProcAddress(handle,"saml_extension_init");
        if (!fn)
            throw SAMLException(string("SAMLConfig::saml_register_extension() unable to locate saml_extension_init entry point: ") + fixed);
        if (reinterpret_cast<int(*)(void*)>(fn)(context)!=0)
            throw SAMLException(string("SAMLConfig::saml_register_extension() detected error in saml_extension_init: ") + fixed);
        if (fixed!=path)
            free(fixed);
        SetErrorMode(em);
    }
    catch(...)
    {
        if (handle)
            FreeLibrary(handle);
        SetErrorMode(em);
        if (fixed!=path)
            free(fixed);
        throw;
    }

#elif defined(HAVE_DLFCN_H)
    void* handle=dlopen(path,RTLD_LAZY);
    if (!handle)
    {
        // dlerror() may return NULL; never append a NULL pointer to a std::string.
        const char* err=dlerror();
        throw SAMLException(string("SAMLConfig::saml_register_extension unable to load extension library '") + path + "': " + (err ? err : "unknown error"));
    }
    int (*fn)(void*)=(int (*)(void*))(dlsym(handle,"saml_extension_init"));
    if (!fn)
    {
        // dlerror() clears the pending error when called, so it must be read exactly
        // once. Calling it twice made the second call return NULL, which was then
        // appended to a std::string. Build the message before dlclose() as well.
        const char* err=dlerror();
        string msg=string("SAMLConfig::saml_register_extension unable to locate saml_extension_init entry point in '") + path + "': " + (err ? err : "unknown error");
        dlclose(handle);
        throw SAMLException(msg);
    }
    try
    {
        if (fn(context)!=0)
            throw SAMLException(string("SAMLConfig::saml_register_extension() detected error in saml_extension_init in ") + path);
    }
    catch(...)
    {
        if (handle)
            dlclose(handle);
        throw;
    }
#else
# error "Don't know about dynamic loading on this platform!"
#endif
    m_libhandles.push_back(handle);
    m_log->info("loaded extension: %s",path);
}

const char* SAMLInternalConfig::getDefaultBindingProvider(const XMLCh* binding) const
{
#ifdef HAVE_GOOD_STL
    map<xstring,string>::const_iterator i=m_bindingMap.find(binding);
#else
    auto_ptr_char temp(binding);
    map<string,string>::const_iterator i=m_bindingMap.find(temp.get());
#endif
    return (i==m_bindingMap.end()) ? NULL : i->second.c_str();
}

void SAMLInternalConfig::setDefaultBindingProvider(const XMLCh* binding, const char* type)
{
#ifdef HAVE_GOOD_STL
    m_bindingMap[binding]=type;
#else
    auto_ptr_char temp(binding);
    m_bindingMap[temp.get()]=type;
#endif
}
