package freenet.session;

import java.math.BigInteger;
import freenet.crypt.*;
import freenet.crypt.ciphers.Rijndael;
import freenet.*;
import freenet.support.Logger;
import freenet.support.io.SafeBufferedInputStream;
import java.io.*;

/**
 * FNP link-layer session over a {@link Connection}: authenticates the peer and
 * wraps the connection's raw streams in buffered Rijndael (PCFB) cipher streams.
 *
 * <p>Two handshakes are supported, selected by the first byte the initiator sends:
 * <ul>
 *   <li><b>restart</b> - reuse a session key cached by the {@link FnpLinkManager}
 *       from an earlier link, identified by its key hash;</li>
 *   <li><b>full negotiation</b> - a Diffie-Hellman exchange authenticated with DSA
 *       signatures, after which the derived key is registered with the manager.</li>
 * </ul>
 *
 * <p>Handshake state is guarded by the monitor of {@code conn}. Stream and peer
 * accessors block in {@link #waitReady()} until a handshake has completed or the
 * link has been closed.
 */
class FnpLink implements LinkConstants, Link {

    /*
     * Handshake byte values. The initiator's first byte carries
     * AUTH_LAYER_VERSION in its top VER_BIT_LENGTH bits and the negotiation
     * mode in the bits selected by NEGOTIATION_MODE_MASK.
     * SILENT_BOB_BYTE / SILENT_BOB_HANGUP are the responder's accept / refuse
     * replies.
     */
    protected static final int
        AUTH_LAYER_VERSION    = 0x01,
        VER_BIT_LENGTH        = 5,
        VER_BIT_MASK          = 0x1f,
        NEGOTIATION_MODE_MASK = 0x03,
        RESTART               = 0x01,
        AUTHENTICATE          = 0x00,
        RESTART_BYTE          = (AUTH_LAYER_VERSION << (8-VER_BIT_LENGTH)) + RESTART,
        AUTHENTICATE_BYTE     = (AUTH_LAYER_VERSION << (8-VER_BIT_LENGTH)) + AUTHENTICATE,
        SILENT_BOB_BYTE       = 0xfb,
        SILENT_BOB_HANGUP     = 0xfc;

    /** Buffered, decrypting input; installed by a handshake, cleared by close(). */
    protected SafeBufferedInputStream in;
    /** Buffered, encrypting output; installed by a handshake, cleared by close(). */
    protected BufferedOutputStream out;
    /** Underlying connection; its monitor guards the handshake state. */
    protected Connection conn;
    /**
     * True once the streams are usable or the link has been closed. Written
     * under the {@code conn} monitor. Volatile because waitReady() reads it
     * without the lock on its fast path; the volatile write (made after the
     * streams are installed) also publishes {@code in} / {@code out}.
     */
    protected volatile boolean ready = false;
    /** Session key and identities for this link (from the link manager). */
    protected FnpLinkToken linkInfo;
    protected FnpLinkManager linkManager;
    /** DLES instance used for restart challenges and the full-negotiation Ca check. */
    protected DLES asymCipher = new DLES();

    /** Creates an unauthenticated link; call accept() or solicit() next. */
    protected FnpLink(FnpLinkManager flm, Connection c) {
        linkManager = flm;
        conn = c;
    }

    /** Creates a link whose cipher streams are already established. */
    protected FnpLink(FnpLinkManager flm, FnpLinkToken linkInfo,
                      CipherInputStream in, CipherOutputStream out,
                      Connection conn) {
        this(flm, conn);
        this.linkInfo = linkInfo;
        setInputStream(in);
        setOutputStream(out);
        ready = true;
    }

    /**
     * Runs the responder ("Bob") side of the handshake on {@code conn}.
     * Core.authTimeout is used as the socket timeout during the handshake, and
     * the previous timeout is restored afterwards.
     *
     * @param privMe   our private DSA key
     * @param pubMe    our public identity
     * @param paravect negotiation flags (SILENT_BOB, VERIFY_BOBKNOWN)
     * @throws CommunicationException if authentication fails, times out, or
     *         an I/O error occurs
     */
    protected void accept(DSAPrivateKey privMe,
                          DSAIdentity pubMe,
                          int paravect) throws CommunicationException {
        synchronized (conn) {
            try {
                int oldTimeout = conn.getSoTimeout();
                try {
                    conn.setSoTimeout(Core.authTimeout);

                    InputStream rawIn = conn.getIn();
                    int negmode = readNegotiationMode(rawIn);

                    if (negmode == RESTART) {
                        Core.logger.log(this, "Accepting restart.", Logger.DEBUG);
                        if (receiveRestartRequest(privMe, pubMe, paravect))
                            return;
                        // We hung up on a safe restart. The initiator falls
                        // back to full negotiation and sends a fresh mode byte,
                        // which must pass the same version check.
                        negmode = readNegotiationMode(rawIn);
                    }

                    if (negmode == AUTHENTICATE) {
                        Core.logger.log(this, "Accepting full negotiation",
                                        Logger.DEBUG);
                        negotiateInbound(privMe, pubMe, paravect);
                    } else {
                        throw new NegotiationFailedException(conn.getPeerAddress(),
                                                             "Invalid authentication mode");
                    }
                } finally {
                    conn.setSoTimeout(oldTimeout);
                }
            } catch (InterruptedIOException e) {
                throw new ConnectFailedException(conn.getPeerAddress(),
                                                 "authentication timed out");
            } catch (IOException e) {
                String s = "I/O error during inbound auth: "+e;
                Core.logger.log(this, s, Logger.MINOR);
                throw new ConnectFailedException(conn.getPeerAddress(), s);
            }
        }
    }

    /**
     * Reads the initiator's mode byte, checks the protocol version in its top
     * bits, and returns the mode bits (RESTART or AUTHENTICATE from a
     * well-behaved peer).
     */
    private int readNegotiationMode(InputStream rawIn)
                        throws CommunicationException, IOException {
        int connType = rawIn.read();
        // EOF (-1) also fails this check: all of its bits are set.
        if (((connType >> (8-VER_BIT_LENGTH)) & VER_BIT_MASK) != AUTH_LAYER_VERSION)
            throw new NegotiationFailedException(conn.getPeerAddress(),
                                                 "Wrong auth protocol version");
        return connType & NEGOTIATION_MODE_MASK;
    }

    /**
     * Runs the initiator ("Alice") side of the handshake with {@code bob}.
     *
     * <p>If the link manager has a cached outbound link to bob, a restart is
     * tried first. If that is refused or fails authentication, the cached
     * link is removed. A refusal falls back to full negotiation; an
     * authentication failure is rethrown.
     *
     * @param safe if true, wait for Bob to confirm a restart before sending
     *             data, so that full negotiation can follow on a refusal
     * @throws CommunicationException if authentication fails, times out, or
     *         an I/O error occurs
     */
    protected void solicit(DSAAuthentity privMe, DSAIdentity pubMe,
                           DSAIdentity bob, boolean safe)
                                            throws CommunicationException {
        synchronized (conn) {
            try {
                int oldTimeout = conn.getSoTimeout();
                conn.setSoTimeout(Core.authTimeout);
                try {
                    linkInfo = linkManager.searchOutboundLinks(bob);
                    if (linkInfo != null) {
                        Core.logger.log(this, "Soliciting restart", Logger.DEBUG);
                        try {
                            boolean worked =
                                negotiateRestart(bob, linkInfo.getKeyHash(),
                                                 linkInfo.getKey(), safe);
                            if (worked)
                                return;
                            linkManager.removeLink(linkInfo);
                        } catch (AuthenticationFailedException e) {
                            linkManager.removeLink(linkInfo);
                            // Rethrow as-is so the original stack trace survives.
                            throw e;
                        }
                    }
                    Core.logger.log(this, "Soliciting full negotiation", Logger.DEBUG);
                    negotiateOutbound(privMe, pubMe, bob);
                } finally {
                    conn.setSoTimeout(oldTimeout);
                }
            } catch (InterruptedIOException e) {
                throw new ConnectFailedException(conn.getPeerAddress(),
                                                 "authentication timed out");
            } catch (IOException e) {
                String s = "I/O error during outbound auth: "+e;
                Core.logger.log(this, s, Logger.MINOR);
                throw new ConnectFailedException(conn.getPeerAddress(), s);
            }
        }
    }

    /**
     * Initiator side of a restart. Sends RESTART_BYTE followed by the DLES
     * encryption (three MPIs) of {@code (hk << 8) | safeFlag} to bob, then
     * switches to the cached key {@code k}.
     *
     * <p>Safe mode flushes the challenge and waits for Bob's reply byte.
     * Otherwise encryption starts immediately, and Bob's SILENT_BOB_BYTE is
     * checked lazily on the first read (see SilentBobCheckingInputStream).
     *
     * @return true if the link is ready; false if Bob hung up on a safe restart
     * @throws NegotiationFailedException if Bob answers a safe restart with an
     *         unexpected byte
     */
    private boolean negotiateRestart(DSAIdentity bob, BigInteger hk,
                                     byte[] k, boolean safe)
                            throws CommunicationException, IOException {

        OutputStream rawOut = conn.getOut();
        InputStream rawIn = conn.getIn();

        // Low byte of the challenge carries the "safe" flag (0 or 1).
        BigInteger M = hk.shiftLeft(8);
        if (safe)
            M = M.setBit(0);

        BigInteger[] C = asymCipher.encrypt(bob, M, Core.randSource);
        rawOut.write(RESTART_BYTE);
        Util.writeMPI(C[0], rawOut);
        Util.writeMPI(C[1], rawOut);
        Util.writeMPI(C[2], rawOut);
        // Deliberately not flushed here (original note: "Flushing is a pass on
        // some level"); only the safe branch flushes, because it must wait
        // for Bob's answer.

        BlockCipher c = new Rijndael();
        c.initialize(k);
        PCFBMode ctx = new PCFBMode(c);

        if (safe) {
            rawOut.flush();
            int cb = rawIn.read();
            if (cb == SILENT_BOB_HANGUP)
                return false;
            if (cb != SILENT_BOB_BYTE)
                throw new NegotiationFailedException(conn.getPeerAddress(),
                                                     "Bad OK byte");
            ctx.writeIV(Core.randSource, rawOut);
            setOutputStream(ctx, rawOut);
            setInputStream(c, rawIn);
        } else {
            ctx.writeIV(Core.randSource, rawOut);
            setOutputStream(ctx, rawOut);
            setSilentBobCheckingInputStream(c, rawIn);
        }
        markReady();
        return true;
    }

    /**
     * Responder side of a restart. Decrypts the challenge, extracts the safe
     * flag from its low byte, and looks the remaining key hash up among
     * cached inbound links. If the hash is known, it replies SILENT_BOB_BYTE
     * and switches to the cached key. Otherwise it replies SILENT_BOB_HANGUP.
     *
     * @return true if the link is ready; false if the key was unknown and
     *         the initiator asked for a safe restart, in which case a full
     *         negotiation follows
     * @throws AuthenticationFailedException if the challenge does not
     *         decrypt, its flag byte is neither 0 nor 1, or the key is
     *         unknown on an unsafe restart
     */
    private boolean receiveRestartRequest(DSAPrivateKey privMe,
                                          DSAIdentity pubMe,
                                          int paravect)
                                throws CommunicationException, IOException {

        OutputStream rawOut = conn.getOut();
        InputStream rawIn = conn.getIn();

        BigInteger[] C = new BigInteger[3];
        C[0] = Util.readMPI(rawIn);
        C[1] = Util.readMPI(rawIn);
        C[2] = Util.readMPI(rawIn);

        BigInteger P;
        try {
            P = asymCipher.decrypt(pubMe.getGroup(), privMe, C);
        } catch (DecryptionFailedException dfe) {
            throw new AuthenticationFailedException(conn.getPeerAddress(),
                      "Invalid restart message (MAC verify must have failed)");
        }

        byte flag = P.byteValue();
        boolean safe;
        if (flag == 0)
            safe = false;
        else if (flag == 1)
            safe = true;
        else
            throw new AuthenticationFailedException(conn.getPeerAddress(),
                        "Invalid restart message (low 8 bits not 0 or 1)");

        linkInfo = linkManager.searchInboundLinks(P.shiftRight(8));
        if (linkInfo == null) {
            rawOut.write(SILENT_BOB_HANGUP);
            rawOut.flush();
            if (safe)
                return false;
            throw new AuthenticationFailedException(conn.getPeerAddress(),
                "Unknown Link trying to restart, and unable to do fallback.");
        }

        rawOut.write(SILENT_BOB_BYTE);

        BlockCipher c = new Rijndael();
        c.initialize(linkInfo.getKey());
        PCFBMode ctx = new PCFBMode(c);
        ctx.writeIV(Core.randSource, rawOut);
        setOutputStream(ctx, rawOut);
        setInputStream(c, rawIn);

        markReady();
        return true;
    }

    /**
     * Responder side of a full negotiation.
     *
     * <p>Reads Alice's DH value Ca and her DLES encryption of it. If
     * VERIFY_BOBKNOWN is set, checks that the encryption decrypts to Ca.
     * Sends SILENT_BOB_BYTE and our value Cb: immediately unless SILENT_BOB
     * is set, otherwise only after Alice's first message has been read and
     * checked. Then switches to the derived key, signs H(Ca||Cb), and
     * verifies Alice's identity Ya and her signature over H(Ya||Ca||Cb).
     * On success, registers the new key with the link manager.
     */
    private void negotiateInbound(DSAPrivateKey privMe,
                                  DSAIdentity pubMe,
                                  int paravect) throws CommunicationException,
                                                       IOException {
        BigInteger[] dhParams = DiffieHellman.getParams();
        BigInteger R = dhParams[0];   // our exponent (used in modPow below)
        BigInteger Cb = dhParams[1];  // our public value, sent to the peer

        OutputStream rawOut = conn.getOut();
        InputStream rawIn = conn.getIn();

        boolean cbSent = false;
        if ((paravect & SILENT_BOB) == 0) {
            sendDHReply(rawOut, Cb);
            cbSent = true;
        }
        BigInteger Ca = Util.readMPI(rawIn);

        BlockCipher c = new Rijndael();
        byte[] k = deriveKey(c, Ca, R);

        BigInteger[] DLESCa = new BigInteger[3];
        DLESCa[0] = Util.readMPI(rawIn);
        DLESCa[1] = Util.readMPI(rawIn);
        DLESCa[2] = Util.readMPI(rawIn);

        if ((paravect & VERIFY_BOBKNOWN) != 0) {
            BigInteger Cav;
            try {
                Cav = asymCipher.decrypt(pubMe.getGroup(), privMe, DLESCa);
            } catch (DecryptionFailedException dfe) {
                throw new AuthenticationFailedException(conn.getPeerAddress(),
                                    "Remote sent bogus DLES encrypted data");
            }
            if (!Cav.equals(Ca)) {
                throw new AuthenticationFailedException(conn.getPeerAddress(),
                                    "Remote does not know my identity");
            }
        }

        if (!cbSent)
            sendDHReply(rawOut, Cb);

        PCFBMode pcfb = new PCFBMode(c);
        pcfb.writeIV(Core.randSource, rawOut);
        setOutputStream(pcfb, rawOut);
        setInputStream(c, rawIn);

        // Our proof: DSA signature over H(Ca || Cb).
        SHA1 ctx = new SHA1();
        byte[] Cabytes = Util.MPIbytes(Ca);
        byte[] Cbbytes = Util.MPIbytes(Cb);
        ctx.update(Cabytes, 0, Cabytes.length);
        ctx.update(Cbbytes, 0, Cbbytes.length);
        BigInteger M = Util.byteArrayToMPI(ctx.digest());
        DSASignature sigCaCb = DSA.sign(pubMe.getGroup(), privMe,
                                        M, Core.randSource);
        sigCaCb.write(out);
        out.flush();

        // Alice's proof: her identity Ya and a signature over H(Ya || Ca || Cb).
        // NOTE(review): ctx is reused, assuming digest() resets it - confirm.
        DSAIdentity Ya = (DSAIdentity) DSAIdentity.read(in);
        byte[] Yabytes = Ya.asBytes();
        ctx.update(Yabytes, 0, Yabytes.length);
        ctx.update(Cabytes, 0, Cabytes.length);
        ctx.update(Cbbytes, 0, Cbbytes.length);
        M = Util.byteArrayToMPI(ctx.digest());
        DSASignature sigYaCaCb = DSASignature.read(in);
        if (!DSA.verify(Ya, sigYaCaCb, M)) {
            throw new AuthenticationFailedException(conn.getPeerAddress(),
                "Remote does not posess the private key to the public key it offered");
        }

        linkInfo = (FnpLinkToken) linkManager.addLink(Ya, pubMe, k);
        markReady();
    }

    /**
     * Initiator side of a full negotiation.
     *
     * <p>Sends AUTHENTICATE_BYTE, our DH value Ca, and the DLES encryption of
     * Ca to bob. Expects SILENT_BOB_BYTE and Bob's value Cb in reply, then
     * switches to the derived key. Sends our identity and a signature over
     * H(Ya||Ca||Cb), then verifies Bob's signature over H(Ca||Cb). On success,
     * registers the new key with the link manager.
     */
    private void negotiateOutbound(DSAPrivateKey privMe,
                                   DSAIdentity pubMe,
                                   DSAIdentity bob) throws CommunicationException,
                                                           IOException {

        BigInteger[] dhParams = DiffieHellman.getParams();
        BigInteger R = dhParams[0];   // our exponent (used in modPow below)
        BigInteger Ca = dhParams[1];  // our public value, sent to the peer

        OutputStream rawOut = conn.getOut();
        InputStream rawIn = conn.getIn();

        rawOut.write(AUTHENTICATE_BYTE);
        Util.writeMPI(Ca, rawOut);

        BigInteger[] DLESCa = asymCipher.encrypt(bob, Ca, Core.randSource);
        Util.writeMPI(DLESCa[0], rawOut);
        Util.writeMPI(DLESCa[1], rawOut);
        Util.writeMPI(DLESCa[2], rawOut);
        rawOut.flush();

        if (SILENT_BOB_BYTE != rawIn.read())
            throw new NegotiationFailedException(
                conn.getPeerAddress(),
                "Bob was not silent in the way that we like"
            );

        BigInteger Cb = Util.readMPI(rawIn);
        Core.logger.log(this, "Read first MPI from peer", Logger.DEBUG);

        BlockCipher c = new Rijndael();
        byte[] k = deriveKey(c, Cb, R);

        PCFBMode pcfb = new PCFBMode(c);
        pcfb.writeIV(Core.randSource, rawOut);
        setOutputStream(pcfb, rawOut);
        setInputStream(c, rawIn);

        // Our proof: identity Ya plus a signature over H(Ya || Ca || Cb).
        pubMe.writeForWire(out);
        SHA1 ctx = new SHA1();
        byte[] Cabytes = Util.MPIbytes(Ca);
        byte[] Cbbytes = Util.MPIbytes(Cb);
        byte[] Yabytes = pubMe.asBytes();
        ctx.update(Yabytes, 0, Yabytes.length);
        ctx.update(Cabytes, 0, Cabytes.length);
        ctx.update(Cbbytes, 0, Cbbytes.length);
        BigInteger M = Util.byteArrayToMPI(ctx.digest());

        DSASignature sigYaCaCb = DSA.sign(pubMe.getGroup(), privMe, M,
                                          Core.randSource);
        sigYaCaCb.write(out);
        out.flush();

        // Bob's proof: signature over H(Ca || Cb).
        // NOTE(review): ctx is reused, assuming digest() resets it - confirm.
        DSASignature sigCaCb = DSASignature.read(in);
        ctx.update(Cabytes, 0, Cabytes.length);
        ctx.update(Cbbytes, 0, Cbbytes.length);
        M = Util.byteArrayToMPI(ctx.digest());
        if (!DSA.verify(bob, sigCaCb, M)) {
            throw new AuthenticationFailedException(conn.getPeerAddress(),
                "Remote is not who she claims to be, or did not receive the correct DH parameters");
        }

        linkInfo = (FnpLinkToken) linkManager.addLink(bob, pubMe, k);
        markReady();
    }

    /** Writes the responder's DH reply (SILENT_BOB_BYTE then Cb) and flushes it. */
    private static void sendDHReply(OutputStream rawOut, BigInteger Cb)
                                                    throws IOException {
        rawOut.write(SILENT_BOB_BYTE);
        Util.writeMPI(Cb, rawOut);
        rawOut.flush();
    }

    /**
     * Computes the DH shared value {@code peerValue^R mod p}, derives a key of
     * the cipher's key size from it with Util.makeKey, and initializes
     * {@code c} with that key.
     *
     * @return the derived key bytes
     */
    private static byte[] deriveKey(BlockCipher c, BigInteger peerValue,
                                    BigInteger R) {
        byte[] k = new byte[c.getKeySize() >> 3];
        BigInteger Z = peerValue.modPow(R, DiffieHellman.getGroup().getP());
        byte[] kent = Util.MPIbytes(Z);
        Util.makeKey(kent, k, 0, k.length);
        c.initialize(k);
        return k;
    }

    /**
     * Marks the link ready and wakes threads blocked in waitReady().
     * Must be called while holding the {@code conn} monitor.
     */
    private void markReady() {
        ready = true;
        conn.notifyAll();
    }

    public final LinkManager getManager() {
        return linkManager;
    }

    /** Blocks until the link is ready, then returns the decrypting input stream. */
    public final InputStream getInputStream() {
        waitReady();
        return in;
    }

    /** Blocks until the link is ready, then returns the encrypting output stream. */
    public final OutputStream getOutputStream() {
        waitReady();
        return out;
    }

    private final void setInputStream(CipherInputStream in) {
        this.in = new SafeBufferedInputStream(in, Core.streamBufferSize);
    }

    private final void setInputStream(BlockCipher c, InputStream raw)
                                                throws IOException {
        this.in = new SafeBufferedInputStream(new CipherInputStream(c, raw,
                                                                    true),
                                              Core.streamBufferSize);
    }

    private final void setSilentBobCheckingInputStream(BlockCipher c,
                                                       InputStream raw)
                                                    throws IOException {
        setInputStream(c, new SilentBobCheckingInputStream(raw));
    }

    /**
     * Checks the 0xfb that should precede Bob's encrypted messages after an
     * unsafe restart, on the first read, without an extra round trip.
     * If the byte is wrong, the link is closed, its cached key is removed,
     * and an IOException is thrown.
     */
    private final class SilentBobCheckingInputStream extends FilterInputStream {
        private boolean checked = false;
        public SilentBobCheckingInputStream(InputStream in) {
            super(in);
        }
        public final int read() throws IOException {
            if (!checked) check();
            return this.in.read();
        }
        public final int read(byte[] buf, int off, int len) throws IOException {
            if (!checked) check();
            return this.in.read(buf, off, len);
        }
        private void check() throws IOException {
            checked = true;
            if (this.in.read() != SILENT_BOB_BYTE) {
                Core.logger.log(FnpLink.this,
                    "Bob didn't send the magic byte after a resume; discarding cached session key",
                    Logger.MINOR);
                FnpLink.this.close();
                linkManager.removeLink(linkInfo);
                throw new IOException("Bob didn't send the magic byte after a resume");
            }
        }
    }

    private final void setOutputStream(CipherOutputStream out) {
        this.out = new BufferedOutputStream(out, Core.streamBufferSize);
    }

    private final void setOutputStream(PCFBMode c, OutputStream raw)
                                                throws IOException {
        this.out = new BufferedOutputStream(new CipherOutputStream(c, raw),
                                            Core.streamBufferSize);
    }

    /**
     * Closes the connection and releases any thread blocked in waitReady().
     * The waiters are released even if {@code conn.close()} throws.
     */
    public void close() throws IOException {
        synchronized (conn) {
            try {
                conn.close();
            } finally {
                markReady();
            }
        }

        // Set streams to null, so the buffer
        // can be gc'ed (using the default configuration this frees
        // up 128k)
        this.in = null;
        this.out = null;
    }

    public final void setTimeout(int timeout) throws IOException {
        conn.setSoTimeout(timeout);
    }

    public final int getTimeout() throws IOException {
        return conn.getSoTimeout();
    }

    /**
     * NOTE(review): unlike getPeerIdentity(), this does not call waitReady(),
     * so linkInfo may still be null before a handshake completes.
     */
    public final Identity getMyIdentity() {
        return linkInfo.getMyIdentity();
    }

    public final Address getMyAddress(ListeningAddress lstaddr)
                                    throws BadAddressException {
        return conn.getMyAddress(lstaddr);
    }

    public final Address getMyAddress() {
        return conn.getMyAddress();
    }

    public final Address getPeerAddress(ListeningAddress lstaddr)
                                    throws BadAddressException {
        waitReady();
        return conn.getPeerAddress(lstaddr);
    }

    public final Address getPeerAddress() {
        waitReady();
        return conn.getPeerAddress();
    }

    public final Identity getPeerIdentity() {
        waitReady();
        return linkInfo.getPeerIdentity();
    }

    /**
     * Blocks until {@code ready} is set, by a handshake or by close().
     * An interrupt does not abort the wait; the thread's interrupt status is
     * restored once the wait ends.
     */
    private void waitReady() {
        if (ready)
            return;
        boolean interrupted = false;
        synchronized (conn) {
            while (!ready) {
                try {
                    conn.wait();
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        }
        if (interrupted)
            Thread.currentThread().interrupt();
    }
}






