package com.ismartcoding.plain.services.webrtc

import android.annotation.SuppressLint
import android.content.Context
import android.graphics.Point
import android.hardware.display.DisplayManager
import android.hardware.display.VirtualDisplay
import android.media.AudioAttributes
import android.media.AudioFormat
import android.media.AudioPlaybackCaptureConfiguration
import android.media.AudioRecord
import android.media.projection.MediaProjection
import android.os.Build
import android.view.Surface
import android.view.WindowManager
import com.ismartcoding.lib.logcat.LogCat
import com.ismartcoding.plain.data.DScreenMirrorQuality
import com.ismartcoding.plain.enums.AppFeatureType
import com.ismartcoding.plain.enums.ScreenMirrorMode
import com.ismartcoding.plain.web.websocket.WebRtcSignalingMessage
import org.webrtc.AudioSource
import org.webrtc.AudioTrack
import org.webrtc.DefaultVideoDecoderFactory
import org.webrtc.DefaultVideoEncoderFactory
import org.webrtc.EglBase
import org.webrtc.MediaConstraints
import org.webrtc.PeerConnectionFactory
import org.webrtc.SurfaceTextureHelper
import org.webrtc.VideoSource
import org.webrtc.VideoTrack
import org.webrtc.audio.JavaAudioDeviceModule
import kotlin.math.max
import kotlin.math.min

/**
 * Manages the shared screen-capture resources (MediaProjection, VirtualDisplay,
 * VideoSource, VideoTrack) and a set of [WebRtcPeerSession]s — one per connected
 * web client.
 *
 * [initCapture] is called exactly once from `ScreenMirrorService.onStartCommand()`
 * with the [MediaProjection] obtained from the one-time-use permission intent.
 * Subsequent orientation or quality changes are handled by [VirtualDisplay.resize],
 * which avoids re-creating the MediaProjection.
 *
 * @param context used for WebRTC initialisation, display metrics, the window manager
 *   and the RECORD_AUDIO permission check.
 * @param getQuality supplies the current [DScreenMirrorQuality]; its `mode` selects the
 *   capture resolution, the target bitrate and whether AUTO-mode stats monitoring runs.
 * @param getIsPortrait orientation supplier. Not consulted at the moment, because
 *   [getRealScreenSize] already reports the dimensions for the current rotation.
 */
class ScreenMirrorWebRtcManager(
    private val context: Context,
    private val getQuality: () -> DScreenMirrorQuality,
    private val getIsPortrait: () -> Boolean,
) {
    // ── Shared capture resources ──────────────────────────────────────────
    private var peerConnectionFactory: PeerConnectionFactory? = null
    private var videoSource: VideoSource? = null
    private var videoTrack: VideoTrack? = null
    private var audioSource: AudioSource? = null
    private var audioTrack: AudioTrack? = null
    private var audioDeviceModule: JavaAudioDeviceModule? = null
    // Written from the WebRTC audio thread (state callback) and reset in releaseAll().
    @Volatile
    private var audioSwapped = false
    private var surfaceTextureHelper: SurfaceTextureHelper? = null
    private var eglBase: EglBase? = null

    // ── MediaProjection + VirtualDisplay (created once, resized as needed) ─
    private var mediaProjection: MediaProjection? = null
    private var virtualDisplay: VirtualDisplay? = null
    private var displaySurface: Surface? = null

    // ── Per-client peer sessions ──────────────────────────────────────────
    private val peerSessions = mutableMapOf<String, WebRtcPeerSession>()

    // ── Adaptive quality state (AUTO mode) ────────────────────────────────
    // Short-side target in AUTO mode; toggled between 1080 and 720 by pollStatsAndAdapt().
    private var adaptiveResolution: Int = 1080
    private var statsHandler: android.os.Handler? = null
    private val statsIntervalMs = 3000L

    // ── Frame-rate limiter (isScreencast=true disables adaptOutputFormat) ─
    private var lastFrameTimeNs = 0L
    private val targetFps = 30
    private val minFrameIntervalNs = 1_000_000_000L / targetFps
    // Slack subtracted from the nominal interval so that frames arriving marginally
    // early because of delivery jitter are not dropped (see initCapture).
    private val frameIntervalToleranceNs = minFrameIntervalNs / 10

    // ── Public API ────────────────────────────────────────────────────────

    /**
     * Initialise capture using the [MediaProjection] obtained from the system.
     * Creates a [VirtualDisplay] that renders screen content into a WebRTC
     * [VideoTrack], plus an audio source/track. Must be called exactly once;
     * later calls are ignored while a VirtualDisplay exists.
     *
     * If the EGL context, the factory or the [SurfaceTextureHelper] cannot be
     * created, this returns early and [videoTrack] stays null, so incoming
     * "ready" messages are rejected by [handleSignaling].
     */
    fun initCapture(projection: MediaProjection) {
        if (virtualDisplay != null) {
            LogCat.d("webrtc: capture already initialised, skipping")
            return
        }

        mediaProjection = projection
        ensurePeerConnectionFactory(projection)
        // Registered before createVirtualDisplay(); currently only logs the stop event.
        projection.registerCallback(object : MediaProjection.Callback() {
            override fun onStop() {
                LogCat.d("webrtc: MediaProjection stopped")
            }
        }, null)

        val egl = eglBase ?: return
        val factory = peerConnectionFactory ?: return

        // SurfaceTextureHelper.create() returns null when EGL setup fails; bail out
        // instead of crashing on a null dereference.
        val helper = SurfaceTextureHelper.create("ScreenCaptureThread", egl.eglBaseContext) ?: run {
            LogCat.e("webrtc: failed to create SurfaceTextureHelper")
            return
        }
        surfaceTextureHelper = helper
        val source = factory.createVideoSource(/* isScreencast = */ true)
        videoSource = source
        videoTrack = factory.createVideoTrack("screen_video", source)

        // Create VirtualDisplay → Surface(SurfaceTexture) → SurfaceTextureHelper → VideoSource
        val (width, height) = computeCaptureSize()
        val dpi = context.resources.displayMetrics.densityDpi

        helper.setTextureSize(width, height)
        val surface = Surface(helper.surfaceTexture)
        displaySurface = surface

        virtualDisplay = projection.createVirtualDisplay(
            "WebRTC_ScreenCapture",
            width, height, dpi,
            DisplayManager.VIRTUAL_DISPLAY_FLAG_AUTO_MIRROR,
            surface,
            null, null,
        )

        // Start forwarding frames: SurfaceTextureHelper → VideoSource
        // Cap at targetFps to prevent encoder overload and latency build-up.
        // (adaptOutputFormat is a no-op when isScreencast=true, so we drop frames manually.)
        // The tolerance matters: with e.g. a 60 Hz source, every second frame arrives
        // roughly one interval after the last accepted one. A strict ">= interval" test
        // rejects it whenever it lands a fraction of a millisecond early, which drops
        // the effective rate well below targetFps.
        helper.startListening { frame ->
            val now = System.nanoTime()
            if (now - lastFrameTimeNs >= minFrameIntervalNs - frameIntervalToleranceNs) {
                lastFrameTimeNs = now
                source.capturerObserver.onFrameCaptured(frame)
            }
            // Skipped frames are automatically released by SurfaceTextureHelper
        }
        source.capturerObserver.onCapturerStarted(true)

        LogCat.d("webrtc: VirtualDisplay created ${width}x${height} dpi=$dpi")

        // Create audio source and track (JavaAudioDeviceModule handles actual capture)
        audioSource = factory.createAudioSource(MediaConstraints())
        audioTrack = factory.createAudioTrack("screen_audio", audioSource)
        audioTrack?.setEnabled(true)
        LogCat.d("webrtc: audio track created, enabled=${audioTrack?.enabled()}")
    }

    /**
     * Swap the internal mic-based AudioRecord inside WebRtcAudioRecord (via reflection)
     * with one that captures system audio using AudioPlaybackCaptureConfiguration.
     * Called on the WebRTC audio recording thread from AudioRecordStateCallback.
     *
     * Best-effort: every failure is logged and swallowed. The old record is released
     * before the new one is built, so a failure after that point leaves WebRTC with no
     * working AudioRecord.
     */
    @SuppressLint("MissingPermission")
    private fun swapToPlaybackCapture(projection: MediaProjection) {
        // AudioPlaybackCaptureConfiguration and AudioRecord.Builder.setAudioPlaybackCaptureConfig
        // exist only on API 29+. The explicit SDK check keeps a NoClassDefFoundError (an Error,
        // not caught below) from escaping if the feature flag were ever enabled on older devices.
        if (Build.VERSION.SDK_INT < Build.VERSION_CODES.Q || !AppFeatureType.MIRROR_AUDIO.has()) {
            LogCat.d("webrtc: audio swap skipped, API < Q")
            return
        }
        if (audioSwapped) {
            LogCat.d("webrtc: audio swap already done")
            return
        }

        // Check RECORD_AUDIO permission at runtime (required for AudioPlaybackCapture)
        if (context.checkSelfPermission(android.Manifest.permission.RECORD_AUDIO)
            != android.content.pm.PackageManager.PERMISSION_GRANTED) {
            LogCat.e("webrtc: RECORD_AUDIO permission not granted, cannot capture audio")
            return
        }

        try {
            val adm = audioDeviceModule ?: run {
                LogCat.e("webrtc: audioDeviceModule is null")
                return
            }

            // Access JavaAudioDeviceModule.audioInput (WebRtcAudioRecord)
            LogCat.d("webrtc: audio swap step 1 - accessing audioInput field")
            val audioInputField = adm.javaClass.getDeclaredField("audioInput")
            audioInputField.isAccessible = true
            val audioInput = audioInputField.get(adm) ?: run {
                LogCat.e("webrtc: audioInput is null")
                return
            }

            // Access WebRtcAudioRecord.audioRecord (android.media.AudioRecord)
            LogCat.d("webrtc: audio swap step 2 - accessing audioRecord field from ${audioInput.javaClass.name}")
            val audioRecordField = audioInput.javaClass.getDeclaredField("audioRecord")
            audioRecordField.isAccessible = true
            val oldRecord = audioRecordField.get(audioInput) as? AudioRecord ?: run {
                LogCat.e("webrtc: audioRecord is null or not AudioRecord")
                return
            }

            // Read params from the existing AudioRecord to match WebRTC's expectations
            val sampleRate = oldRecord.sampleRate
            val channelCount = oldRecord.channelCount
            val channelConfig = if (channelCount == 1) AudioFormat.CHANNEL_IN_MONO else AudioFormat.CHANNEL_IN_STEREO
            val encoding = oldRecord.audioFormat
            LogCat.d("webrtc: audio swap step 3 - old record params: rate=$sampleRate ch=$channelCount encoding=$encoding state=${oldRecord.state}")

            // Stop & release the mic-based AudioRecord
            try { oldRecord.stop() } catch (e: Exception) {
                LogCat.d("webrtc: old record stop exception (expected): ${e.message}")
            }
            oldRecord.release()
            LogCat.d("webrtc: audio swap step 4 - old record released")

            // Create a new AudioRecord that captures system audio via MediaProjection
            val playbackConfig = AudioPlaybackCaptureConfiguration.Builder(projection)
                .addMatchingUsage(AudioAttributes.USAGE_MEDIA)
                .addMatchingUsage(AudioAttributes.USAGE_GAME)
                .addMatchingUsage(AudioAttributes.USAGE_UNKNOWN)
                .build()

            val bufferSize = AudioRecord.getMinBufferSize(sampleRate, channelConfig, encoding) * 2
            LogCat.d("webrtc: audio swap step 5 - creating playback capture AudioRecord (bufSize=$bufferSize)")
            val newRecord = AudioRecord.Builder()
                .setAudioPlaybackCaptureConfig(playbackConfig)
                .setAudioFormat(
                    AudioFormat.Builder()
                        .setSampleRate(sampleRate)
                        .setChannelMask(channelConfig)
                        .setEncoding(encoding)
                        .build()
                )
                .setBufferSizeInBytes(bufferSize)
                .build()

            if (newRecord.state != AudioRecord.STATE_INITIALIZED) {
                LogCat.e("webrtc: Playback-capture AudioRecord failed to initialise (state=${newRecord.state})")
                newRecord.release()
                return
            }

            // Replace the field and start the new AudioRecord
            audioRecordField.set(audioInput, newRecord)
            newRecord.startRecording()
            audioSwapped = true

            LogCat.d("webrtc: audio swap DONE - system audio capture active (rate=$sampleRate ch=$channelCount recordingState=${newRecord.recordingState})")
        } catch (e: Exception) {
            LogCat.e("webrtc: Failed to swap to playback capture: ${e.javaClass.simpleName}: ${e.message}")
            e.printStackTrace()
        }
    }

    /** Drop the audio track reference and dispose the audio source. */
    private fun releaseAudioCapture() {
        audioTrack = null
        audioSource?.dispose()
        audioSource = null
    }

    /**
     * Route a signaling [message] from web client [clientId].
     *
     * - "ready": (re)creates the client's [WebRtcPeerSession] and starts the offer.
     *   In AUTO mode it also (re)starts stats monitoring.
     * - "answer" / "ice_candidate": forwarded to the client's session, if one exists.
     * - anything else is logged and ignored.
     */
    fun handleSignaling(clientId: String, message: WebRtcSignalingMessage) {
        when (message.type) {
            "ready" -> {
                LogCat.d("webrtc: ready from $clientId")
                val factory = peerConnectionFactory
                val track = videoTrack
                if (factory == null || track == null) {
                    LogCat.e("webrtc: capturer not initialised, ignoring ready")
                    return
                }

                // Tear down any previous session for this client (re-negotiation).
                peerSessions.remove(clientId)?.release()

                val session = WebRtcPeerSession(clientId, factory, track, audioTrack, { computeTargetBitrateKbps() }, { getQuality().mode })
                peerSessions[clientId] = session
                session.createPeerConnectionAndOffer()

                if (getQuality().mode == ScreenMirrorMode.AUTO) {
                    startStatsMonitoring()
                }
            }

            "answer" -> {
                if (!message.sdp.isNullOrBlank()) {
                    peerSessions[clientId]?.handleAnswer(message.sdp)
                }
            }

            "ice_candidate" -> {
                if (!message.candidate.isNullOrBlank()) {
                    peerSessions[clientId]?.handleIceCandidate(message)
                }
            }

            else -> {
                LogCat.d("webrtc: ignore signaling type=${message.type}")
            }
        }
    }

    /**
     * Re-apply the current quality setting. AUTO resets the adaptive resolution to
     * 1080 and starts stats monitoring; other modes stop it. The VirtualDisplay is
     * then resized and every session's bitrate is updated.
     */
    fun onQualityChanged() {
        val quality = getQuality()
        if (quality.mode == ScreenMirrorMode.AUTO) {
            adaptiveResolution = 1080
            startStatsMonitoring()
        } else {
            stopStatsMonitoring()
        }
        resizeVirtualDisplay()
        peerSessions.values.forEach { it.updateVideoBitrate() }
    }

    /** Resize the VirtualDisplay to the current screen dimensions. */
    fun onOrientationChanged() {
        resizeVirtualDisplay()
    }

    /**
     * Release the session for [clientId]. When no sessions remain, stats polling is
     * stopped because there is nothing left to poll. A later "ready" restarts it
     * in AUTO mode.
     */
    fun removeClient(clientId: String) {
        peerSessions.remove(clientId)?.release()
        if (peerSessions.isEmpty()) {
            stopStatsMonitoring()
        }
    }

    /**
     * Tear down all sessions and capture resources. Peer sessions go first, then
     * audio, the display pipeline, the projection, the video source, and finally
     * the factory and EGL context.
     */
    fun releaseAll() {
        stopStatsMonitoring()
        peerSessions.values.forEach { it.release() }
        peerSessions.clear()

        releaseAudioCapture()
        audioDeviceModule?.release()
        audioDeviceModule = null
        audioSwapped = false

        virtualDisplay?.release()
        virtualDisplay = null

        displaySurface?.release()
        displaySurface = null

        surfaceTextureHelper?.stopListening()
        videoSource?.capturerObserver?.onCapturerStopped()

        mediaProjection?.stop()
        mediaProjection = null

        surfaceTextureHelper?.dispose()
        surfaceTextureHelper = null

        videoTrack = null
        videoSource?.dispose()
        videoSource = null

        peerConnectionFactory?.dispose()
        peerConnectionFactory = null

        eglBase?.release()
        eglBase = null
    }

    // ── Private helpers ───────────────────────────────────────────────────

    /**
     * Lazily create the EGL context, the [JavaAudioDeviceModule] and the
     * [PeerConnectionFactory]. [PeerConnectionFactory.initialize] runs at most once
     * per process (tracked by the companion flag). [projection] is captured by the
     * audio-record callback so that system audio can be swapped in once WebRTC
     * starts recording.
     */
    private fun ensurePeerConnectionFactory(projection: MediaProjection) {
        if (peerConnectionFactory != null) return

        if (!webrtcInitialized) {
            PeerConnectionFactory.initialize(
                PeerConnectionFactory.InitializationOptions.builder(context)
                    .setEnableInternalTracer(false)
                    .createInitializationOptions(),
            )
            webrtcInitialized = true
        }

        val egl = EglBase.create()
        eglBase = egl

        // Create JavaAudioDeviceModule for system audio capture.
        // Disable HW AEC/NS (not needed for system audio) and register a
        // state callback to swap the internal mic AudioRecord with one
        // using AudioPlaybackCaptureConfiguration once recording starts.
        val adm = JavaAudioDeviceModule.builder(context)
            .setUseHardwareAcousticEchoCanceler(false)
            .setUseHardwareNoiseSuppressor(false)
            .setAudioRecordStateCallback(object : JavaAudioDeviceModule.AudioRecordStateCallback {
                override fun onWebRtcAudioRecordStart() {
                    swapToPlaybackCapture(projection)
                }
                override fun onWebRtcAudioRecordStop() {
                    audioSwapped = false
                }
            })
            .createAudioDeviceModule()
        // Keep the ADM Java object alive (released in releaseAll) so that our
        // AudioRecordStateCallback can use reflection to swap the internal AudioRecord later.
        audioDeviceModule = adm

        val encoderFactory = DefaultVideoEncoderFactory(egl.eglBaseContext, true, true)
        val decoderFactory = DefaultVideoDecoderFactory(egl.eglBaseContext)
        peerConnectionFactory = PeerConnectionFactory.builder()
            .setVideoEncoderFactory(encoderFactory)
            .setVideoDecoderFactory(decoderFactory)
            .setAudioDeviceModule(adm)
            .createPeerConnectionFactory()
    }

    /**
     * Resize the existing [VirtualDisplay] to match the current quality / orientation.
     * No need to recreate the MediaProjection or VirtualDisplay.
     */
    private fun resizeVirtualDisplay() {
        val vd = virtualDisplay ?: return
        val (width, height) = computeCaptureSize()
        val dpi = context.resources.displayMetrics.densityDpi

        surfaceTextureHelper?.setTextureSize(width, height)
        vd.resize(width, height, dpi)

        LogCat.d("webrtc: VirtualDisplay resized ${width}x${height} dpi=$dpi")
    }

    /**
     * Compute the capture size from the real screen size. The image is scaled so
     * that its short side is at most [getEffectiveResolution] (never upscaled).
     * Both dimensions are forced even and kept at least 2 px.
     */
    private fun computeCaptureSize(): Pair<Int, Int> {
        val realSize = getRealScreenSize()
        val width = realSize.x
        val height = realSize.y

        val shortSide = min(width, height)
        val targetShort = getEffectiveResolution()
        val scale = min(1f, targetShort.toFloat() / shortSide.toFloat())

        // Even dimensions — presumably for the encoder's chroma subsampling.
        val targetWidth = makeEven(max(2, (width * scale).toInt()))
        val targetHeight = makeEven(max(2, (height * scale).toInt()))

        return Pair(targetWidth, targetHeight)
    }

    /**
     * Get the real physical screen dimensions including system bars.
     * displayMetrics.widthPixels/heightPixels may exclude the navigation bar
     * on Android <= 11, causing wrong capture aspect ratio and black bars.
     */
    private fun getRealScreenSize(): Point {
        val wm = context.getSystemService(Context.WINDOW_SERVICE) as WindowManager
        return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) {
            val bounds = wm.currentWindowMetrics.bounds
            Point(bounds.width(), bounds.height())
        } else {
            val size = Point()
            @Suppress("DEPRECATION")
            wm.defaultDisplay.getRealSize(size)
            size
        }
    }

    /** Target short-side resolution for the current quality mode. */
    private fun getEffectiveResolution(): Int {
        val quality = getQuality()
        return when (quality.mode) {
            ScreenMirrorMode.AUTO -> adaptiveResolution
            ScreenMirrorMode.HD -> 1080
            ScreenMirrorMode.SMOOTH -> 720
        }
    }

    /** Target video bitrate (kbps) for the current effective resolution. */
    private fun computeTargetBitrateKbps(): Int {
        val resolution = getEffectiveResolution()
        // LAN bitrates for screen content (sharp text/UI edges)
        // resolution = short side, so 1080p on a 20:9 phone ≈ 1080×2400
        return when {
            resolution >= 1080 -> 20000
            resolution >= 720 -> 10000
            else -> 4000
        }
    }

    // ── Adaptive stats monitoring (AUTO mode) ─────────────────────────────

    /**
     * (Re)start the periodic stats poll on the main looper, every [statsIntervalMs].
     * The loop stops rescheduling itself once the mode is no longer AUTO.
     */
    private fun startStatsMonitoring() {
        stopStatsMonitoring()
        if (getQuality().mode != ScreenMirrorMode.AUTO) return

        statsHandler = android.os.Handler(android.os.Looper.getMainLooper())
        statsHandler?.postDelayed(object : Runnable {
            override fun run() {
                if (getQuality().mode != ScreenMirrorMode.AUTO) return
                pollStatsAndAdapt()
                statsHandler?.postDelayed(this, statsIntervalMs)
            }
        }, statsIntervalMs)
    }

    /** Cancel any pending stats poll. */
    private fun stopStatsMonitoring() {
        statsHandler?.removeCallbacksAndMessages(null)
        statsHandler = null
    }

    /**
     * Sample stats from the first session. Switch [adaptiveResolution] between 720
     * and 1080 based on packet loss, RTT and available bitrate; on a change, resize
     * the display and update every session's bitrate.
     */
    private fun pollStatsAndAdapt() {
        val session = peerSessions.values.firstOrNull() ?: return
        session.getStats { availableBitrateKbps, packetLossPercent, rttMs ->
            val oldResolution = adaptiveResolution

            // Downgrade: high packet loss or low available bitrate
            val shouldDowngrade = packetLossPercent > 5.0 || rttMs > 150 ||
                    (availableBitrateKbps in 1 until 8000)
            // Upgrade: plenty of bandwidth and good network
            val shouldUpgrade = availableBitrateKbps > 15000 && packetLossPercent < 1.0 && rttMs < 50

            if (shouldDowngrade && adaptiveResolution > 720) {
                adaptiveResolution = 720
            } else if (shouldUpgrade && adaptiveResolution < 1080) {
                adaptiveResolution = 1080
            }

            if (oldResolution != adaptiveResolution) {
                LogCat.d("webrtc: adaptive resolution $oldResolution → $adaptiveResolution " +
                    "(bw=${availableBitrateKbps}kbps loss=${String.format("%.1f", packetLossPercent)}% rtt=${String.format("%.0f", rttMs)}ms)")
                resizeVirtualDisplay()
                peerSessions.values.forEach { it.updateVideoBitrate() }
            }
        }
    }

    /** Round [value] down to the nearest even number. */
    private fun makeEven(value: Int): Int = if (value % 2 == 0) value else value - 1

    companion object {
        // PeerConnectionFactory.initialize() is process-wide; run it only once.
        private var webrtcInitialized = false
    }
}
