package com.example.sekreto

import java.nio.ByteBuffer
import java.nio.charset.Charset
import java.security.SecureRandom
import javax.crypto.BadPaddingException
import javax.crypto.Cipher
import javax.crypto.SecretKey
import javax.crypto.SecretKeyFactory
import javax.crypto.spec.GCMParameterSpec
import javax.crypto.spec.PBEKeySpec
import javax.crypto.spec.SecretKeySpec

/**
 * Password-based authenticated encryption of text using AES-256 in GCM mode.
 *
 * The key is derived from the password with PBKDF2 (HMAC-SHA1) and a fresh random salt
 * per message. The encrypted output is an upper-case hex string encoding
 * `salt (16 bytes) || IV (12 bytes) || ciphertext + GCM tag (16 bytes)`.
 */
class Crypto {

    private val secureRandom = SecureRandom()

    /**
     * Derives the AES key for [password] and [salt] using PBKDF2.
     *
     * The temporary copies of the password held by this function and by the key spec are
     * wiped once the key has been derived.
     */
    private fun deriveKey(password: String, salt: ByteArray): SecretKey {
        val passwordChars = password.toCharArray()
        val spec = PBEKeySpec(passwordChars, salt, ITERATIONS, KEY_LENGTH * 8)
        try {
            val keyBytes = SecretKeyFactory.getInstance(FACTORY_INSTANCE).generateSecret(spec).encoded
            return SecretKeySpec(keyBytes, "AES")
        } finally {
            spec.clearPassword()
            passwordChars.fill('\u0000')
        }
    }

    /** Returns [length] cryptographically random bytes (used for both salt and IV). */
    private fun randomBytes(length: Int): ByteArray =
        ByteArray(length).also { secureRandom.nextBytes(it) }

    /** Creates an AES/GCM cipher initialised for [mode] with the given key and IV. */
    private fun initCipher(mode: Int, secretKey: SecretKey, iv: ByteArray): Cipher =
        Cipher.getInstance(CIPHER_ALGORITHM).apply {
            init(mode, secretKey, GCMParameterSpec(TAG_LENGTH * 8, iv))
        }

    /** Encodes [bytes] as an upper-case hex string, two characters per byte. */
    private fun toHex(bytes: ByteArray): String = buildString(bytes.size * 2) {
        for (b in bytes) {
            val value = b.toInt() and 0xFF
            append(HEX_DIGITS[value ushr 4])
            append(HEX_DIGITS[value and 0x0F])
        }
    }

    /**
     * Decodes a hex string (upper or lower case) into bytes.
     *
     * @throws IllegalArgumentException if [hex] has an odd length or contains a
     *   character that is not an ASCII hex digit.
     */
    private fun fromHex(hex: String): ByteArray {
        // An odd length means a dangling half-byte; reject it instead of silently dropping it.
        require(hex.length % 2 == 0) { "Hex input must contain an even number of characters" }
        return ByteArray(hex.length / 2) { i ->
            ((hexNibble(hex[2 * i]) shl 4) or hexNibble(hex[2 * i + 1])).toByte()
        }
    }

    /** Returns the value (0..15) of a single ASCII hex digit, or throws for any other character. */
    private fun hexNibble(c: Char): Int = when (c) {
        in '0'..'9' -> c - '0'
        in 'a'..'f' -> c - 'a' + 10
        in 'A'..'F' -> c - 'A' + 10
        else -> throw IllegalArgumentException("Invalid hex character: '$c'")
    }

    /**
     * Encrypts [plainMessage] (encoded as UTF-8) with a key derived from [password].
     *
     * A new random salt and IV are generated on every call, so encrypting the same
     * message twice yields different output.
     *
     * @return upper-case hex of `salt || IV || ciphertext + tag`.
     */
    fun encrypt(password: String, plainMessage: String): String {
        val salt = randomBytes(SALT_LENGTH)
        val iv = randomBytes(IV_LENGTH)

        val cipher = initCipher(Cipher.ENCRYPT_MODE, deriveKey(password, salt), iv)
        val encrypted = cipher.doFinal(plainMessage.toByteArray(UTF_8))

        val payload = ByteBuffer.allocate(salt.size + iv.size + encrypted.size)
            .put(salt)
            .put(iv)
            .put(encrypted)
            .array()

        return toHex(payload)
    }

    /**
     * Decrypts a hex string produced by [encrypt].
     *
     * @param cipherContent hex of `salt || IV || ciphertext + tag`.
     * @param password the password the content was encrypted with.
     * @return the decrypted text, or `null` if [cipherContent] is malformed (not valid hex,
     *   too short) or cannot be authenticated/decrypted with [password]. The failure's
     *   stack trace is printed to standard error.
     */
    fun decrypt(cipherContent: String, password: String): String? =
        try {
            val payload = fromHex(cipherContent)
            // Even an empty message carries salt, IV and the authentication tag.
            require(payload.size >= SALT_LENGTH + IV_LENGTH + TAG_LENGTH) {
                "Cipher content is too short"
            }

            val ivStart = SALT_LENGTH
            val contentStart = SALT_LENGTH + IV_LENGTH
            val salt = payload.copyOfRange(0, ivStart)
            val iv = payload.copyOfRange(ivStart, contentStart)
            val content = payload.copyOfRange(contentStart, payload.size)

            val cipher = initCipher(Cipher.DECRYPT_MODE, deriveKey(password, salt), iv)
            String(cipher.doFinal(content), UTF_8)
        } catch (e: Exception) {
            // Covers malformed input as well as a failed GCM tag check; callers rely on
            // receiving null for every failure mode.
            e.printStackTrace()
            null
        }

    private companion object {
        // Sizes are in bytes. These values, the cipher and the KDF define the stored
        // format: decrypt re-derives the key with them, so changing any of them makes
        // previously encrypted content undecryptable.
        private const val TAG_LENGTH = 16
        private const val IV_LENGTH = 12
        private const val SALT_LENGTH = 16
        private const val KEY_LENGTH = 32
        private const val ITERATIONS = 65535

        private const val CIPHER_ALGORITHM = "AES/GCM/NoPadding"
        private const val FACTORY_INSTANCE = "PBKDF2WithHmacSHA1"
        private const val HEX_DIGITS = "0123456789ABCDEF"

        private val UTF_8: Charset = Charsets.UTF_8
    }
}