package com.commit451.gitlab.ssl

import com.commit451.gitlab.api.X509TrustManagerProvider
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.net.ssl.*

/**
 * An [X509TrustManager] that delegates to the manager exposed by [X509TrustManagerProvider] and
 * layers two user-supplied overrides on top of it: one trusted certificate fingerprint and one
 * trusted hostname.
 *
 * The instance also hands out the [SSLSocketFactory] and [HostnameVerifier] built from that
 * configuration. Both are created on first request, cached, and discarded again when the
 * setting they depend on changes.
 *
 * No client key material is installed: the SSL context is initialised with `null` key managers.
 * Nothing in this class is synchronized.
 * NOTE(review): confirm callers configure and use an instance from a single thread.
 */
class CustomTrustManager : X509TrustManager {

    private var trustedCertificate: String? = null
    private var trustedHostname: String? = null
    private var sslSocketFactory: SSLSocketFactory? = null
    private var hostnameVerifier: HostnameVerifier? = null

    /**
     * Stores the fingerprint that [checkServerTrusted] accepts when the delegate rejects a chain,
     * and drops the cached socket factory so the next [getSSLSocketFactory] call builds a new one.
     *
     * @param trustedCertificate value compared for equality against `X509Util.getFingerPrint`
     * of the first certificate in the chain (format is whatever that helper produces)
     */
    fun setTrustedCertificate(trustedCertificate: String) {
        this.trustedCertificate = trustedCertificate
        sslSocketFactory = null
    }

    /**
     * Stores the hostname handed to [CustomHostnameVerifier] and drops the cached verifier so the
     * next [getHostnameVerifier] call builds a new one.
     *
     * @param trustedHostname hostname passed to the verifier's constructor
     */
    fun setTrustedHostname(trustedHostname: String) {
        this.trustedHostname = trustedHostname
        hostnameVerifier = null
    }

    /**
     * Forwards client validation to the delegate trust manager unchanged.
     *
     * @throws CertificateException if the delegate rejects the chain
     */
    @Throws(CertificateException::class)
    override fun checkClientTrusted(chain: Array<X509Certificate>, authType: String) {
        X509TrustManagerProvider.x509TrustManager.checkClientTrusted(chain, authType)
    }

    /**
     * Accepts the chain if the delegate trust manager accepts it, or, failing that, if the
     * fingerprint of the first certificate equals the configured trusted certificate.
     *
     * @throws CertificateException an [X509CertificateException] wrapping the delegate's failure
     * (and given the rejected chain) when neither check passes
     */
    @Throws(CertificateException::class)
    override fun checkServerTrusted(chain: Array<X509Certificate>, authType: String) {
        try {
            X509TrustManagerProvider.x509TrustManager.checkServerTrusted(chain, authType)
        } catch (cause: CertificateException) {
            if (matchesTrustedCertificate(chain)) {
                return
            }
            // A CertificateException may carry no message; fall back to its string form rather
            // than crashing with a NullPointerException that would hide the real failure.
            throw X509CertificateException(cause.message ?: cause.toString(), cause, chain)
        }
    }

    /** Returns the delegate trust manager's accepted issuers. */
    override fun getAcceptedIssuers(): Array<X509Certificate> {
        return X509TrustManagerProvider.x509TrustManager.acceptedIssuers
    }

    /**
     * Returns the cached socket factory, building it on first use (or after
     * [setTrustedCertificate] cleared it) from a "TLS" [SSLContext] whose only trust manager is
     * this instance.
     *
     * @throws IllegalStateException wrapping any exception raised while building the factory
     */
    fun getSSLSocketFactory(): SSLSocketFactory {
        sslSocketFactory?.let { return it }

        val factory: SSLSocketFactory = try {
            val sslContext = SSLContext.getInstance("TLS")
            // null key managers and null SecureRandom: let the platform use its defaults.
            sslContext.init(null, arrayOf<TrustManager>(this), null)
            CustomSSLSocketFactory(sslContext.socketFactory)
        } catch (e: Exception) {
            throw IllegalStateException(e)
        }

        sslSocketFactory = factory
        return factory
    }

    /**
     * Returns the cached hostname verifier, building it on first use (or after
     * [setTrustedHostname] cleared it) from the currently configured trusted hostname, which may
     * still be `null`.
     */
    fun getHostnameVerifier(): HostnameVerifier =
        hostnameVerifier ?: CustomHostnameVerifier(trustedHostname).also { hostnameVerifier = it }

    /**
     * Tells whether the first certificate of [chain] has the configured trusted fingerprint.
     * Returns `false` when no fingerprint is configured or the chain is empty.
     */
    private fun matchesTrustedCertificate(chain: Array<X509Certificate>): Boolean {
        val pinned = trustedCertificate ?: return false
        val leaf = chain.firstOrNull() ?: return false
        return pinned == X509Util.getFingerPrint(leaf)
    }
}
