package org.codeberg.quecomet.oshi.network

import android.annotation.SuppressLint
import androidx.compose.runtime.Stable
import org.codeberg.quecomet.oshi.SSLCheckMode
import org.codeberg.quecomet.oshi.data.room.OshiInstance
import java.security.KeyStore
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
import javax.inject.Inject
import javax.inject.Singleton
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager

/**
 * [X509TrustManager] whose server-certificate policy is selected by [SSLCheckMode].
 *
 * - [SSLCheckMode.FINGERPRINT_ONLY]: the platform trust manager is NOT consulted. Only the SHA-256
 *   and SHA-1 fingerprints of the first certificate in the chain are compared against the values
 *   stored on [OshiInstance].
 * - [SSLCheckMode.CERTIFICATE_AND_FINGERPRINT]: the platform trust manager validates the chain
 *   first, then the fingerprint check above is applied.
 * - Any other mode: only the platform trust manager validates the chain.
 *
 * Client-certificate checks and [getAcceptedIssuers] always delegate to the platform default trust
 * manager, which is initialised from the default keystore.
 *
 * NOTE(review): [oshiInstance] is captured once when this singleton is constructed. If the stored
 * fingerprints or the user's confirmation can change at runtime, confirm that the injected
 * [OshiInstance] reflects those updates.
 */
@SuppressLint("CustomX509TrustManager")
@Singleton
@Stable
class CustomX509TrustManager
@Inject
constructor(
    private val sslCheckMode: SSLCheckMode,
    private val oshiInstance: OshiInstance,
) : X509TrustManager {

  /** First [X509TrustManager] produced by the default-algorithm [TrustManagerFactory]. */
  private val defaultTrustManager: X509TrustManager =
      TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
          .also { it.init(null as KeyStore?) } // a null KeyStore means "use the default keystore"
          .trustManagers
          .filterIsInstance<X509TrustManager>()
          .first()

  /** Client certificates are always validated by the platform default trust manager. */
  override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {
    defaultTrustManager.checkClientTrusted(chain, authType)
  }

  /**
   * Validates the server certificate [chain] according to [sslCheckMode].
   *
   * @throws IllegalArgumentException if [chain] is null or empty.
   * @throws CertificateException if platform chain validation or the fingerprint check fails.
   */
  override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
    val certificates =
        requireNotNull(chain?.takeIf { it.isNotEmpty() }) { "Certificate chain is empty" }

    // Every mode except FINGERPRINT_ONLY runs the platform chain validation, and runs it first.
    if (sslCheckMode != SSLCheckMode.FINGERPRINT_ONLY) {
      defaultTrustManager.checkServerTrusted(certificates, authType)
    }
    // The two fingerprint modes additionally pin the first certificate of the chain.
    if (sslCheckMode == SSLCheckMode.FINGERPRINT_ONLY ||
        sslCheckMode == SSLCheckMode.CERTIFICATE_AND_FINGERPRINT) {
      checkFingerprint(certificates[0])
    }
  }

  /** Issuers trusted by the platform default trust manager. */
  override fun getAcceptedIssuers(): Array<X509Certificate> = defaultTrustManager.acceptedIssuers

  /**
   * Accepts [serverCert] only when all of the following hold:
   * - the user has confirmed a fingerprint on [oshiInstance];
   * - both stored fingerprints are non-blank;
   * - both stored fingerprints equal the ones computed from [serverCert].
   *
   * @throws CertificateException if any of those conditions does not hold.
   */
  private fun checkFingerprint(serverCert: X509Certificate) {
    val expectedSha256 = oshiInstance.sha256Fingerprint
    val expectedSha1 = oshiInstance.sha1Fingerprint

    // Reject before hashing the certificate when no comparison could possibly succeed.
    val hasConfirmedPin =
        oshiInstance.fingerprintConfirmedByUser &&
            !expectedSha256.isNullOrBlank() &&
            !expectedSha1.isNullOrBlank()

    if (!hasConfirmedPin ||
        serverCert.getSha256Fingerprint() != expectedSha256 ||
        serverCert.getSha1Fingerprint() != expectedSha1) {
      throw CertificateException("Certificate fingerprint mismatch")
    }
  }
}
