/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.tracecompass.incubator.internal.rocm.core.analysis;

import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.eclipse.jdt.annotation.NonNull;
import org.eclipse.tracecompass.analysis.os.linux.core.model.HostThread;
import org.eclipse.tracecompass.incubator.callstack.core.base.EdgeStateValue;
import org.eclipse.tracecompass.statesystem.core.ITmfStateSystemBuilder;
import org.eclipse.tracecompass.tmf.core.event.ITmfEvent;
import org.eclipse.tracecompass.tmf.core.event.ITmfEventField;
import org.eclipse.tracecompass.tmf.core.statesystem.AbstractTmfStateProvider;
import org.eclipse.tracecompass.tmf.core.statesystem.ITmfStateProvider;
import org.eclipse.tracecompass.tmf.core.trace.ITmfTrace;

/**
 * State provider that converts ROCm CTF trace events into call stacks stored
 * in a state system.
 *
 * <p>Events are routed by event name:
 * <ul>
 * <li>{@code hip_api}, {@code hsa_api}, {@code kfd_api}: pushed/popped under
 * {@code Processes/System/Thread <tid>/<EVENT NAME>/CallStack};</li>
 * <li>{@code compute_kernels_hsa}: pushed/popped under
 * {@code Processes/GPU <id>/Queues/Queue <id>/CallStack}, mirrored under
 * {@code Processes/GPU <id>/HIP Streams/Stream <id>/CallStack} when a matching
 * HIP launch is known, and tracked in an "Idle" gap call stack;</li>
 * <li>{@code hcc_ops}: {@code Processes/GPU Activity/GPU Kernels/CallStack};</li>
 * <li>{@code async_copy}, {@code roctx}:
 * {@code Processes/Memory//Memory Transfers/CallStack}.</li>
 * </ul>
 * Events with any other name are ignored.
 */
public class RocmCtfCallStackStateProvider
extends AbstractTmfStateProvider {
    private static final String ID = "org.eclipse.tracecompass.incubator.rocm.ctf.callstackstateprovider";
    static final @NonNull String EDGES_LANE = "EDGES";

    /** Sentinel returned by the quark helpers when an event cannot be mapped to an attribute. */
    private static final int NO_QUARK = -1;
    /**
     * Value compared against the result of {@code optQuarkRelative} to detect a
     * missing attribute. NOTE(review): presumably equal to
     * {@code ITmfStateSystem.INVALID_ATTRIBUTE} - confirm.
     */
    private static final int MISSING_ATTRIBUTE = -2;
    /** Multiplier used for API event names that are not registered in {@link #fApiId}. */
    private static final int DEFAULT_API_ID = 4;
    private static final String ENTER_SUFFIX = "_enter";
    private static final String EXIT_SUFFIX = "_exit";
    /** Extracts the decimal stream id from the {@code args} field of a HIP launch event. */
    private static final Pattern HIP_STREAM_PATTERN = Pattern.compile("stream\\((\\d+)\\)");
    /** Per-API multiplier applied to a thread id to build the value stored on the call stack's parent. */
    private static final Map<String, Integer> fApiId = new HashMap<String, Integer>();

    /** Dispatch ids of kernels whose begin event was seen but whose end event was not. */
    final List<Long> fCurrentKernelDispatched = new LinkedList<Long>();
    /** {@code hipLaunchKernel_enter} events, keyed by the order in which they were seen. */
    Map<Long, ITmfEvent> fHipKernelDispatchs = new HashMap<Long, ITmfEvent>();
    /**
     * Key for the next {@code hipLaunchKernel_enter} event. It is later looked up
     * with the {@code kernel_dispatch_id} of kernel events, so this assumes the
     * n-th launch corresponds to dispatch id n - TODO confirm.
     */
    Long fHipDispatchCounter = 0L;

    static {
        fApiId.put("hip_api", 1);
        fApiId.put("hsa_api", 2);
        fApiId.put("kfd_api", 3);
    }

    /**
     * Creates the provider for the given trace.
     *
     * @param trace the trace this provider reads events from
     */
    public RocmCtfCallStackStateProvider(@NonNull ITmfTrace trace) {
        super(trace, ID);
    }

    /**
     * @return the version number of this provider, currently 1
     */
    public int getVersion() {
        return 1;
    }

    /**
     * @return a fresh provider bound to the same trace
     */
    public ITmfStateProvider getNewInstance() {
        return new RocmCtfCallStackStateProvider(this.getTrace());
    }

    /**
     * Dispatches one event to the handler matching its name. Events without
     * content, or that map to no call stack attribute, are dropped.
     *
     * @param event the event to process
     */
    protected void eventHandle(ITmfEvent event) {
        if (event.getContent() == null) {
            return;
        }
        ITmfStateSystemBuilder ssb = getStateSystemBuilder();
        if (ssb == null) {
            return;
        }
        int quark = getCorrectQuark(ssb, event);
        if (quark == NO_QUARK) {
            return;
        }
        switch (event.getName()) {
            case "compute_kernels_hsa":
                processGpuEvent(event, ssb, quark);
                break;
            case "async_copy":
                processMemoryCopies(event, ssb, quark);
                break;
            default:
                processApiEvent(event, ssb, quark);
                break;
        }
    }

    /**
     * Returns the function name to display for an enter event: the event name
     * without its {@code _enter} suffix, or the unchanged name when the suffix
     * is absent (so short or unsuffixed names are never truncated).
     */
    private static String stripEnterSuffix(String eventName) {
        if (eventName.endsWith(ENTER_SUFFIX)) {
            return eventName.substring(0, eventName.length() - ENTER_SUFFIX.length());
        }
        return eventName;
    }

    /**
     * Handles an {@code async_copy} event: pops on a {@code *_exit} name, pushes
     * the function name otherwise, and clears the attribute when the event has
     * no {@code name} field.
     */
    private static void processMemoryCopies(ITmfEvent event, ITmfStateSystemBuilder ssb, int quark) {
        long timestamp = event.getTimestamp().toNanos();
        String eventName = event.getContent().getFieldValue(String.class, "name");
        if (eventName == null) {
            ssb.modifyAttribute(timestamp, null, quark);
        } else if (eventName.endsWith(EXIT_SUFFIX)) {
            ssb.popAttribute(timestamp, quark);
        } else {
            ssb.pushAttribute(timestamp, stripEnterSuffix(eventName), quark);
        }
    }

    /**
     * Handles a {@code compute_kernels_hsa} event. The first event seen for a
     * dispatch id is treated as the kernel begin, the next one with the same id
     * as its end. Events lacking {@code kernel_name}, {@code kernel_dispatch_id}
     * or {@code gpu_id} are ignored.
     */
    private void processGpuEvent(ITmfEvent event, ITmfStateSystemBuilder ssb, int callStackQuark) {
        ITmfEventField content = event.getContent();
        long timestamp = event.getTimestamp().toNanos();
        String kernelName = content.getFieldValue(String.class, "kernel_name");
        Long dispatchId = content.getFieldValue(Long.class, "kernel_dispatch_id");
        Long gpuId = content.getFieldValue(Long.class, "gpu_id");
        if (kernelName == null || dispatchId == null || gpuId == null) {
            return;
        }
        // Must run before the dispatch list is updated: it reads the list to
        // tell a kernel begin from a kernel end.
        updateGpuGapState(ssb, timestamp, gpuId, dispatchId);
        if (fCurrentKernelDispatched.remove(dispatchId)) {
            endKernel(ssb, timestamp, callStackQuark, gpuId, dispatchId);
        } else {
            beginKernel(event, ssb, timestamp, callStackQuark, kernelName, gpuId, dispatchId);
        }
    }

    /** Pops the kernel from its queue call stack and, when known, from its HIP stream call stack. */
    private void endKernel(ITmfStateSystemBuilder ssb, long timestamp, int callStackQuark, Long gpuId, Long dispatchId) {
        ssb.popAttribute(timestamp, callStackQuark);
        ITmfEvent hipEvent = fHipKernelDispatchs.remove(dispatchId);
        if (hipEvent == null) {
            return;
        }
        int streamQuark = getHipStreamCallStackQuark(ssb, hipEvent, gpuId);
        if (streamQuark != NO_QUARK) {
            ssb.popAttribute(timestamp, streamQuark);
        }
    }

    /**
     * Pushes the kernel on its queue call stack and, when a matching HIP launch
     * event is known, on the HIP stream call stack, then records edges from the
     * launching thread to the queue lane and to the stream lane.
     */
    private void beginKernel(ITmfEvent event, ITmfStateSystemBuilder ssb, long timestamp, int callStackQuark,
            String kernelName, Long gpuId, Long dispatchId) {
        fCurrentKernelDispatched.add(dispatchId);
        long traceStart = getTrace().getStartTime().toNanos();
        // Queue and stream lanes of a GPU get the ids 2*gpu and 2*gpu+1.
        int queueLaneId = gpuId.intValue() * 2;
        int streamLaneId = queueLaneId + 1;

        ssb.pushAttribute(timestamp, kernelName, callStackQuark);
        if (ssb.queryOngoing(callStackQuark) != null) {
            ssb.modifyAttribute(traceStart, queueLaneId, ssb.getParentAttributeQuark(callStackQuark));
        }

        ITmfEvent hipEvent = fHipKernelDispatchs.get(dispatchId);
        if (hipEvent == null) {
            return;
        }
        int streamQuark = getHipStreamCallStackQuark(ssb, hipEvent, gpuId);
        boolean hasStream = streamQuark != NO_QUARK;
        if (hasStream) {
            ssb.pushAttribute(timestamp, kernelName, streamQuark);
            if (ssb.queryOngoing(streamQuark) != null) {
                ssb.modifyAttribute(traceStart, streamLaneId, ssb.getParentAttributeQuark(streamQuark));
            }
        }

        Long tid = hipEvent.getContent().getFieldValue(Long.class, "tid");
        if (tid == null) {
            return;
        }
        String hostId = event.getTrace().getHostId();
        long launchTime = hipEvent.getTimestamp().toNanos();
        int sourceId = Math.toIntExact(tid) * fApiId.getOrDefault(hipEvent.getName(), DEFAULT_API_ID);
        int arrowId = Math.toIntExact(dispatchId);
        HostThread src = new HostThread(hostId, Integer.valueOf(sourceId));
        addArrow(ssb, launchTime, timestamp, arrowId, src, new HostThread(hostId, Integer.valueOf(queueLaneId)));
        if (hasStream) {
            addArrow(ssb, launchTime, timestamp, arrowId, src, new HostThread(hostId, Integer.valueOf(streamLaneId)));
        }
    }

    /** Returns (creating it if needed) the quark of {@code Processes/GPU <gpuId>}. */
    private static int getGpuQuark(ITmfStateSystemBuilder ssb, Long gpuId) {
        return ssb.getQuarkAbsoluteAndAdd("Processes", "GPU " + gpuId);
    }

    /**
     * Maintains the "Idle" call stack under
     * {@code Processes/GPU <id>//Gap Analysis/CallStack}: "Idle" is pushed at the
     * trace start the first time a GPU is seen, popped when a kernel begins and
     * pushed again when a kernel ends. Does nothing when {@code gpuId} is null.
     */
    private void updateGpuGapState(ITmfStateSystemBuilder ssb, long timestamp, Long gpuId, Long dispatchId) {
        if (gpuId == null) {
            return;
        }
        int gpuQuark = getGpuQuark(ssb, gpuId);
        // Checked before the attribute is created below.
        boolean isNewGap = ssb.optQuarkRelative(gpuQuark, "") == MISSING_ATTRIBUTE;
        int interQuark = ssb.getQuarkRelativeAndAdd(gpuQuark, "");
        int gapQuark = ssb.getQuarkRelativeAndAdd(interQuark, "Gap Analysis");
        int callStackQuark = ssb.getQuarkRelativeAndAdd(gapQuark, "CallStack");
        if (isNewGap) {
            ssb.pushAttribute(getTrace().getStartTime().toNanos(), "Idle", callStackQuark);
        }
        if (fCurrentKernelDispatched.contains(dispatchId)) {
            // Kernel end: the GPU becomes idle again.
            ssb.pushAttribute(timestamp, "Idle", callStackQuark);
        } else {
            // Kernel begin: the idle period ends.
            ssb.popAttribute(timestamp, callStackQuark);
        }
    }

    /**
     * Stores an edge from {@code src} to {@code dest} over
     * {@code [startTime, endTime)} in a free attribute under {@code Edges}.
     */
    private static void addArrow(ITmfStateSystemBuilder ssb, long startTime, long endTime, int id, @NonNull HostThread src, @NonNull HostThread dest) {
        int edgeQuark = getAvailableEdgeQuark(ssb, startTime);
        EdgeStateValue edgeStateValue = new EdgeStateValue(id, src, dest);
        ssb.modifyAttribute(startTime, (Object) edgeStateValue, edgeQuark);
        ssb.modifyAttribute(endTime, null, edgeQuark);
    }

    /**
     * Finds a child of {@code Edges} whose ongoing value is null and whose
     * ongoing state starts no later than {@code startTime}; creates a new child
     * named after the current child count when none qualifies.
     */
    private static int getAvailableEdgeQuark(ITmfStateSystemBuilder ssb, long startTime) {
        int edgeRoot = ssb.getQuarkAbsoluteAndAdd("Edges");
        List<Integer> subQuarks = ssb.getSubAttributes(edgeRoot, false);
        for (int quark : subQuarks) {
            if (ssb.queryOngoing(quark) == null && ssb.getOngoingStartTime(quark) <= startTime) {
                return quark;
            }
        }
        return ssb.getQuarkRelativeAndAdd(edgeRoot, Integer.toString(subQuarks.size()));
    }

    /**
     * Handles every event that is neither a kernel nor an async copy: pushes the
     * function name on enter, pops on {@code *_exit} or when the {@code name}
     * field is missing, and remembers {@code hipLaunchKernel_enter} events so
     * kernel events can later be linked to them.
     */
    private void processApiEvent(ITmfEvent event, ITmfStateSystemBuilder ssb, int callStackQuark) {
        ITmfEventField content = event.getContent();
        long timestamp = event.getTimestamp().toNanos();
        String eventName = content.getFieldValue(String.class, "name");
        if (eventName == null) {
            ssb.popAttribute(timestamp, callStackQuark);
            return;
        }
        Long tid = content.getFieldValue(Long.class, "tid");
        if (tid != null && ssb.queryOngoing(callStackQuark) != null) {
            int laneId = tid.intValue() * fApiId.getOrDefault(event.getName(), DEFAULT_API_ID);
            ssb.modifyAttribute(getTrace().getStartTime().toNanos(), laneId, ssb.getParentAttributeQuark(callStackQuark));
        }
        if (eventName.equals("hipLaunchKernel_enter")) {
            fHipKernelDispatchs.put(fHipDispatchCounter, event);
            fHipDispatchCounter = fHipDispatchCounter + 1L;
        }
        if (eventName.endsWith(EXIT_SUFFIX)) {
            ssb.popAttribute(timestamp, callStackQuark);
        } else {
            ssb.pushAttribute(timestamp, stripEnterSuffix(eventName), callStackQuark);
        }
    }

    /**
     * Maps an event to the call stack attribute it belongs to.
     *
     * @return the call stack quark, or -1 when the event name is not handled or
     *         required fields are missing
     */
    private static int getCorrectQuark(ITmfStateSystemBuilder ssb, @NonNull ITmfEvent event) {
        switch (event.getName()) {
            case "kfd_api":
            case "hip_api":
            case "hsa_api":
                return getApiCallStackQuark(ssb, event);
            case "async_copy":
            case "roctx":
            case "compute_kernels_hsa":
            case "hcc_ops":
                return getGpuActivityCallStackQuark(ssb, event);
            default:
                return NO_QUARK;
        }
    }

    /**
     * Returns the call stack quark for GPU-side events: a per-GPU, per-queue
     * stack for {@code compute_kernels_hsa} (or -1 when {@code queue_id} or
     * {@code gpu_id} is missing), a shared "GPU Kernels" stack for
     * {@code hcc_ops}, and a shared "Memory Transfers" stack for anything else.
     */
    private static int getGpuActivityCallStackQuark(ITmfStateSystemBuilder ssb, @NonNull ITmfEvent event) {
        String name = event.getName();
        if (name.equals("compute_kernels_hsa")) {
            Long queueId = event.getContent().getFieldValue(Long.class, "queue_id");
            Long gpuId = event.getContent().getFieldValue(Long.class, "gpu_id");
            if (queueId == null || gpuId == null) {
                return NO_QUARK;
            }
            int queuesQuark = ssb.getQuarkRelativeAndAdd(getGpuQuark(ssb, gpuId), "Queues");
            int queueQuark = ssb.getQuarkRelativeAndAdd(queuesQuark, "Queue " + queueId);
            return ssb.getQuarkRelativeAndAdd(queueQuark, "CallStack");
        }
        if (name.equals("hcc_ops")) {
            int gpuActivity = ssb.getQuarkAbsoluteAndAdd("Processes", "GPU Activity");
            int kernelsQuark = ssb.getQuarkRelativeAndAdd(gpuActivity, "GPU Kernels");
            return ssb.getQuarkRelativeAndAdd(kernelsQuark, "CallStack");
        }
        int memoryQuark = ssb.getQuarkAbsoluteAndAdd("Processes", "Memory");
        int emptyQuark = ssb.getQuarkRelativeAndAdd(memoryQuark, "");
        int transfersQuark = ssb.getQuarkRelativeAndAdd(emptyQuark, "Memory Transfers");
        return ssb.getQuarkRelativeAndAdd(transfersQuark, "CallStack");
    }

    /**
     * Returns the quark of
     * {@code Processes/System/Thread <tid>/<EVENT NAME IN UPPER CASE>/CallStack},
     * using thread id 0 when the event has no {@code tid} field.
     */
    private static int getApiCallStackQuark(ITmfStateSystemBuilder ssb, @NonNull ITmfEvent event) {
        int systemQuark = ssb.getQuarkAbsoluteAndAdd("Processes", "System");
        Long threadId = event.getContent().getFieldValue(Long.class, "tid");
        if (threadId == null) {
            threadId = 0L;
        }
        int threadQuark = ssb.getQuarkRelativeAndAdd(systemQuark, "Thread " + threadId);
        // Locale.ROOT keeps the attribute name independent of the default locale.
        int apiQuark = ssb.getQuarkRelativeAndAdd(threadQuark, event.getName().toUpperCase(Locale.ROOT));
        return ssb.getQuarkRelativeAndAdd(apiQuark, "CallStack");
    }

    /**
     * Returns the quark of
     * {@code Processes/GPU <gpuId>/HIP Streams/Stream <n>/CallStack}, where
     * {@code n} is parsed from a {@code stream(<n>)} token in the {@code args}
     * field of the given HIP launch event.
     *
     * @return the call stack quark, or -1 when {@code args} is missing, holds no
     *         decimal stream id, or the id does not fit in an {@code int}
     */
    private static int getHipStreamCallStackQuark(ITmfStateSystemBuilder ssb, @NonNull ITmfEvent event, Long gpuId) {
        String args = event.getContent().getFieldValue(String.class, "args");
        if (args == null) {
            return NO_QUARK;
        }
        Matcher matcher = HIP_STREAM_PATTERN.matcher(args);
        if (!matcher.find()) {
            return NO_QUARK;
        }
        int hipStreamId;
        try {
            hipStreamId = Integer.parseInt(matcher.group(1));
        } catch (NumberFormatException ignored) {
            // Only reachable when the digits overflow an int: no usable stream id.
            return NO_QUARK;
        }
        int hipStreamsQuark = ssb.getQuarkRelativeAndAdd(getGpuQuark(ssb, gpuId), "HIP Streams");
        int hipStreamQuark = ssb.getQuarkRelativeAndAdd(hipStreamsQuark, "Stream " + hipStreamId);
        return ssb.getQuarkRelativeAndAdd(hipStreamQuark, "CallStack");
    }
}

