/*
Paros and its related class files.
Paros is an HTTP/HTTPS proxy for assessing web application security.
Copyright (C) 2003-2004 www.proofsecure.com

This program is free software; you can redistribute it and/or
modify it under the terms of the Clarified Artistic License
as published by the Free Software Foundation.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
Clarified Artistic License for more details.

You should have received a copy of the Clarified Artistic License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
*/
package com.proofsecure.paros.network;


import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.security.KeyStore;

import javax.net.ServerSocketFactory;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

/**
 * Creates SSL client sockets (direct, or tunnelled through an HTTP proxy with
 * a CONNECT request) and SSL server sockets.
 * <p>
 * Client sockets are built with {@code RelaxedX509TrustManager}, so every
 * server certificate is accepted; this lets the proxy talk to sites with
 * invalid certificates, at the cost of no server authentication.
 * <p>
 * {@link #init()} must be called before any client or listen method, and
 * {@link #setClientCert(File, char[])} must succeed before asking for a socket
 * with {@code useClientCert == true}.
 */
public class SSLConnector {

	private static final String CRLF = "\r\n";

	/** Client factory without a client certificate; set by {@link #init()}. */
	private SSLSocketFactory clientSSLSockFactory = null;
	/** Client factory presenting the certificate loaded by {@link #setClientCert(File, char[])}. */
	private SSLSocketFactory clientSSLSockCertFactory = null;
	/** Server factory; set by {@link #init()}, null if the server keystore could not be loaded. */
	private ServerSocketFactory serverSSLSockFactory = null;


	public SSLConnector() {
	}

	/**
	 * Builds the client and server socket factories for the "SSL" protocol.
	 * Note that {@link #getClientSocketFactory(String)} also installs the client
	 * factory as the JVM-wide default for {@link HttpsURLConnection}.
	 */
	public void init() {
		clientSSLSockFactory = getClientSocketFactory("SSL");
		serverSSLSockFactory = getServerSocketFactory("SSL");
	}

	/**
	 * Connects directly to {@code hostName:hostPort} and completes the SSL handshake.
	 *
	 * @param hostName      target host
	 * @param hostPort      target port
	 * @param useClientCert whether to present the certificate loaded by setClientCert
	 * @return a connected, handshaken socket
	 * @throws IOException if the connection or handshake fails (the socket is closed on failure)
	 */
	public SSLSocket client(String hostName, int hostPort, boolean useClientCert) throws IOException {
		SSLSocket socket = clientNoHandshake(hostName, hostPort, useClientCert);
		handshakeOrClose(socket);
		return socket;
	}

	/**
	 * Connects directly to {@code hostName:hostPort} without starting the SSL handshake.
	 *
	 * @param hostName      target host
	 * @param hostPort      target port
	 * @param useClientCert whether to present the certificate loaded by setClientCert
	 * @return a connected socket whose handshake has not yet been started
	 * @throws IOException           if the TCP connection cannot be opened
	 * @throws IllegalStateException if the required factory has not been set up
	 */
	public SSLSocket clientNoHandshake(String hostName, int hostPort, boolean useClientCert) throws IOException {
		return (SSLSocket) selectClientFactory(useClientCert).createSocket(hostName, hostPort);
	}

	/**
	 * Tunnels to {@code hostName:hostPort} through an HTTP proxy and completes
	 * the SSL handshake.
	 *
	 * @return a connected, handshaken socket
	 * @throws IOException if the proxy refuses the tunnel, or if the connection
	 *                     or handshake fails (sockets are closed on failure)
	 */
	public SSLSocket clientViaProxy(String hostName, int hostPort, String proxyName, int proxyPort, boolean useClientCert) throws IOException {
		SSLSocket socket = clientViaProxyNoHandshake(hostName, hostPort, proxyName, proxyPort, useClientCert);
		if (socket == null) {
			// Previously this dereferenced null; report the refused tunnel instead.
			throw new IOException("Unable to establish tunnel to " + hostName + ":" + hostPort
					+ " via proxy " + proxyName + ":" + proxyPort);
		}
		handshakeOrClose(socket);
		return socket;
	}

	/**
	 * Tunnels to {@code hostName:hostPort} through an HTTP proxy and layers SSL
	 * over the tunnel, without starting the handshake.
	 *
	 * @return the layered socket, or {@code null} if the proxy did not accept the CONNECT request
	 * @throws IOException           on I/O failure (the tunnel is closed on failure)
	 * @throws IllegalStateException if the required factory has not been set up
	 */
	public SSLSocket clientViaProxyNoHandshake(String hostName, int hostPort, String proxyName, int proxyPort, boolean useClientCert) throws IOException {

		HttpResponseHeader res = new HttpResponseHeader();
		Socket tunnel = establishTunnel(hostName, hostPort, proxyName, proxyPort, res);
		if (tunnel == null) {
			return null;
		}

		boolean layered = false;
		try {
			// autoClose = true: closing the SSL socket also closes the tunnel.
			SSLSocket socket = (SSLSocket) selectClientFactory(useClientCert).createSocket(tunnel, hostName, hostPort, true);
			layered = true;
			return socket;
		} finally {
			if (!layered) {
				closeQuietly(tunnel);
			}
		}
	}

	/**
	 * Loads a PKCS#12 client certificate and builds the factory used when
	 * {@code useClientCert} is true. Server certificates are still not validated.
	 * On any failure the certificate factory is left {@code null}.
	 *
	 * @param keyFile    PKCS#12 keystore file
	 * @param passPhrase password for both the keystore and the key
	 * @throws Exception if the keystore cannot be read or the SSL context cannot be initialised
	 */
	public void setClientCert(File keyFile, char[] passPhrase) throws Exception {

		clientSSLSockCertFactory = null;
		TrustManager[] trustMgr = {new RelaxedX509TrustManager()};	// Trust all invalid server certificate

		SSLContext ctx = SSLContext.getInstance("SSL");
		KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
		KeyStore ks = KeyStore.getInstance("pkcs12");

		FileInputStream in = new FileInputStream(keyFile);
		try {
			ks.load(in, passPhrase);
		} finally {
			in.close();
		}
		kmf.init(ks, passPhrase);
		// A default-constructed SecureRandom self-seeds; seeding it with the clock made it predictable.
		ctx.init(kmf.getKeyManagers(), trustMgr, new java.security.SecureRandom());

		clientSSLSockCertFactory = ctx.getSocketFactory();
	}

	/**
	 * Opens a plain socket to the proxy, sends a CONNECT request for
	 * {@code hostName:hostPort}, and copies the proxy's response header into {@code res}.
	 *
	 * @param res receives the proxy's response, so callers can inspect it even on refusal
	 * @return the open tunnel socket, or {@code null} if the response is malformed or its
	 *         status is not {@code HttpStatusCode.OK} (the socket is closed in that case)
	 * @throws IOException on I/O failure (the socket is closed before the exception propagates)
	 */
	public Socket establishTunnel(String hostName, int hostPort, String proxyName, int proxyPort, HttpResponseHeader res) throws IOException {
		Socket tunnel = new Socket(proxyName, proxyPort);
		boolean established = false;
		try {
			HttpInputStream tunnelIn = new HttpInputStream(tunnel.getInputStream());
			HttpOutputStream tunnelOut = new HttpOutputStream(tunnel.getOutputStream());
			HttpRequestHeader req = new HttpRequestHeader(getConnectString(hostName, hostPort));

			tunnelOut.write(req);
			tunnelOut.flush();
			HttpResponseHeader tunnelRes = (HttpResponseHeader) tunnelIn.readHeader();
			res.setMessage(tunnelRes.toString());
			if (res.isMalformedHeader() || res.getStatusCode() != HttpStatusCode.OK) {
				return null;
			}
			established = true;
			return tunnel;
		} finally {
			// Never leak the proxy connection when the tunnel is not handed to the caller.
			if (!established) {
				closeQuietly(tunnel);
			}
		}
	}

	/**
	 * Creates a server socket on {@code portNum} using the server factory built by {@link #init()}.
	 * Throws NullPointerException if init() was not called or the server keystore failed to load.
	 */
	public ServerSocket listen(int portNum) throws IOException {
		return serverSSLSockFactory.createServerSocket(portNum);
	}

	/** Creates an unbound server socket using the server factory built by {@link #init()}. */
	public ServerSocket listen() throws IOException {
		return serverSSLSockFactory.createServerSocket();
	}

	/** Creates a server socket on {@code portNum} with the given backlog ({@code maxConnection}). */
	public ServerSocket listen(int portNum, int maxConnection) throws IOException {
		return serverSSLSockFactory.createServerSocket(portNum, maxConnection);
	}

	/** Creates a server socket on {@code paramPortNum}, with the given backlog, bound to {@code ip}. */
	public ServerSocket listen(int paramPortNum, int maxConnection, InetAddress ip ) throws IOException {
		return serverSSLSockFactory.createServerSocket(paramPortNum, maxConnection, ip);
	}


	/**
	 * Builds a client socket factory for protocol {@code type} that trusts every
	 * server certificate. It also installs that factory as the default for
	 * {@link HttpsURLConnection}, and stores it in this connector's client factory field.
	 *
	 * @param type SSLContext protocol name, e.g. "SSL"
	 * @return the new factory; on failure the error is printed and the previous
	 *         field value (possibly {@code null}) is returned
	 */
	public SSLSocketFactory getClientSocketFactory(String type) {
		TrustManager[] trustMgr = new TrustManager[]{new RelaxedX509TrustManager()};	// Trust all invalid server certificate

		try {
			SSLContext sslContext = SSLContext.getInstance(type);
			sslContext.init(null, trustMgr, new java.security.SecureRandom());
			clientSSLSockFactory = sslContext.getSocketFactory();
			HttpsURLConnection.setDefaultSSLSocketFactory(clientSSLSockFactory);
		} catch (Exception e) {
			e.printStackTrace();
		}

		return clientSSLSockFactory;
	}


	/**
	 * Builds a server socket factory. For "SSL" and "SSLv3" it loads the JKS
	 * keystore file "paroskey" (resolved against the working directory) as the
	 * server identity; for any other type it returns the plain default factory.
	 *
	 * @param type protocol name
	 * @return the factory, or {@code null} if the SSL setup failed (the error is printed)
	 */
	public ServerSocketFactory getServerSocketFactory(String type) {

		if (!type.equals("SSL") && !type.equals("SSLv3")) {
			return ServerSocketFactory.getDefault();
		}

		try {
			// Password of the bundled "paroskey" keystore.
			char[] passphrase = "!@#$%^&*()".toCharArray();

			SSLContext ctx = SSLContext.getInstance(type);
			KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
			KeyStore ks = KeyStore.getInstance("JKS");

			FileInputStream in = new FileInputStream("paroskey");
			try {
				ks.load(in, passphrase);
			} finally {
				in.close();
			}
			kmf.init(ks, passphrase);
			ctx.init(kmf.getKeyManagers(), null, new java.security.SecureRandom());

			SSLServerSocketFactory ssf = ctx.getServerSocketFactory();
			return ssf;
		} catch (Exception e) {
			e.printStackTrace();
		}
		return null;
	}

	/** Builds the HTTP/1.0 CONNECT request text for {@code hostName:hostPort}. */
	private static String getConnectString(String hostName, int hostPort) {
		StringBuffer sb = new StringBuffer(200);
		sb.append("CONNECT ").append(hostName).append(':').append(hostPort).append(" HTTP/1.0").append(CRLF);
		sb.append("User-Agent: Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.0;)").append(CRLF);
		sb.append("Host: ").append(hostName).append(':').append(hostPort).append(CRLF);
		sb.append("Pragma: no-cache").append(CRLF);
		sb.append("Content-Length: 0").append(CRLF);
		sb.append(CRLF);
		return sb.toString();
	}

	/** Returns the client factory for the requested mode, failing clearly if it was never set up. */
	private SSLSocketFactory selectClientFactory(boolean useClientCert) {
		SSLSocketFactory factory = useClientCert ? clientSSLSockCertFactory : clientSSLSockFactory;
		if (factory == null) {
			throw new IllegalStateException(useClientCert
					? "No client certificate loaded; call setClientCert() first"
					: "SSLConnector not initialised; call init() first");
		}
		return factory;
	}

	/** Starts the SSL handshake, closing the socket if it fails for any reason. */
	private static void handshakeOrClose(SSLSocket socket) throws IOException {
		boolean done = false;
		try {
			socket.startHandshake();
			done = true;
		} finally {
			if (!done) {
				closeQuietly(socket);
			}
		}
	}

	/** Closes a socket, ignoring errors so the original failure reaches the caller. */
	private static void closeQuietly(Socket socket) {
		try {
			socket.close();
		} catch (IOException ignored) {
			// Best effort cleanup only.
		}
	}


}

/**
 * Trust manager that accepts every certificate chain without any validation.
 * <p>
 * Used deliberately so the proxy can connect to servers presenting
 * self-signed, expired or otherwise invalid certificates. It provides no
 * protection against man-in-the-middle attacks and must not be used where
 * certificate validation matters.
 * <p>
 * The boolean {@code is*Trusted} methods and the single-argument
 * {@code checkClientTrusted} are not part of {@link X509TrustManager}. They
 * always return {@code true} and are kept for existing callers.
 */
class RelaxedX509TrustManager implements X509TrustManager {

	/** Always trusts the client chain. */
	public boolean checkClientTrusted(java.security.cert.X509Certificate[] chain){
		return true;
	}

	/** Always trusts the server chain. */
	public boolean isServerTrusted(java.security.cert.X509Certificate[] chain){
		return true;
	}

	/** Always trusts the client chain. */
	public boolean isClientTrusted(java.security.cert.X509Certificate[] chain){
		return true;
	}

	/**
	 * Returns an empty array: no issuer is specifically preferred.
	 * The X509TrustManager contract requires a non-null (possibly empty) array,
	 * so returning {@code null} here could break JSSE code that iterates it.
	 */
	public java.security.cert.X509Certificate[] getAcceptedIssuers() {
		return new java.security.cert.X509Certificate[0];
	}

	/** Accepts any client chain: never throws. */
	public void checkClientTrusted(java.security.cert.X509Certificate[] chain, String authType)
	{}

	/** Accepts any server chain: never throws. */
	public void checkServerTrusted(java.security.cert.X509Certificate[] chain, String authType)
	{}
}

