/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners.profiler.comparison;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Pattern;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.profiler.ProfilingListener;
import org.nd4j.autodiff.listeners.profiler.comparison.Config;
import org.nd4j.autodiff.listeners.profiler.comparison.OpStats;
import org.nd4j.autodiff.listeners.profiler.data.Phase;
import org.nd4j.autodiff.listeners.profiler.data.TraceEvent;
import org.nd4j.autodiff.listeners.profiler.data.TraceEvents;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.list.NDArrayList;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ProfileAnalyzer {
    private static final Logger log = LoggerFactory.getLogger(ProfileAnalyzer.class);
    private static Map<String, String> TF_PROFILE_ALIASES = new HashMap<String, String>();

    public static void summarizeProfile(File file, ProfileFormat profileFormat) {
        System.out.println(ProfileAnalyzer.summarizeProfileStr(file, profileFormat));
    }

    public static String summarizeProfileStr(File file, ProfileFormat profileFormat) {
        TraceEvent[] events = ProfileAnalyzer.getTraceEvents(file, profileFormat);
        return ProfileAnalyzer.summarizeTraceEvents(events);
    }

    public static void summarizeProfileDirectory(File dir, ProfileFormat profileFormat) {
        System.out.println(ProfileAnalyzer.summarizeProfileDirectoryStr(dir, profileFormat));
    }

    public static String summarizeProfileDirectoryStr(File dir, ProfileFormat profileFormat) {
        return ProfileAnalyzer.summarizeTraceEvents(ProfileAnalyzer.getTraceEventsDir(dir, profileFormat));
    }

    public static TraceEvent[] getTraceEventsDir(File dir, ProfileFormat profileFormat) {
        File[] files = dir.listFiles();
        Preconditions.checkState(files != null && files.length > 0, "No profiles found in directory: %s", (Object)dir);
        ArrayList l = new ArrayList();
        for (File f : files) {
            if (!f.getName().endsWith(".json")) {
                log.info("Skipping non-JSON file in directory - {}", (Object)f.getAbsolutePath());
                continue;
            }
            TraceEvent[] e = ProfileAnalyzer.getTraceEvents(f, profileFormat);
            Collections.addAll(l, e);
        }
        return l.toArray(new TraceEvent[0]);
    }

    public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat) {
        return ProfileAnalyzer.getTraceEvents(file, profileFormat, true);
    }

    /*
     * WARNING - void declaration
     */
    public static TraceEvent[] getTraceEvents(File file, ProfileFormat profileFormat, boolean aggregateTFSubOps) {
        TraceEvents traceEvents;
        TraceEvent[] events;
        ObjectMapper json = ProfilingListener.jsonMapper();
        String content = null;
        try (BufferedInputStream bufferedInputStream = new BufferedInputStream(new FileInputStream(file));){
            try {
                content = IOUtils.toString((InputStream)bufferedInputStream, (Charset)Charset.defaultCharset());
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        catch (FileNotFoundException e) {
            e.printStackTrace();
        }
        catch (IOException e) {
            e.printStackTrace();
        }
        if (!content.matches(".*]\\s*")) {
            content = content.endsWith(",") ? content.substring(0, content.length() - 1) + "]" : (content.endsWith(",\n") ? content.substring(0, content.length() - 2) + "]" : content + "]");
        }
        if (profileFormat == ProfileFormat.SAMEDIFF) {
            try {
                events = json.readValue(content, TraceEvent[].class);
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        try {
            traceEvents = json.readValue(content, TraceEvents.class);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        for (TraceEvent traceEvent : events = traceEvents.getTraceEvents().toArray(new TraceEvent[0])) {
            DifferentialFunction df;
            if (TF_PROFILE_ALIASES.containsKey(traceEvent.getName())) {
                traceEvent.setName(TF_PROFILE_ALIASES.get(traceEvent.getName()));
            }
            if ((df = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(traceEvent.getName())) == null) continue;
            traceEvent.setName(df.opName());
        }
        if (aggregateTFSubOps) {
            void var10_24;
            HashMap<String, TraceEvent> map = new HashMap<String, TraceEvent>();
            ArrayList<TraceEvent> out = new ArrayList<TraceEvent>();
            TraceEvent last = null;
            for (TraceEvent te : events) {
                if (last != null && last.getPh() == Phase.X && te.getPh() == Phase.X && last.getName().equals(te.getName()) && last.getArgs() != null && te.getArgs() != null && last.getArgs().get("name").equals(te.getArgs().get("name")) && last.getArgs().get("op").equals(te.getArgs().get("op"))) {
                    last.setDur(last.getDur() + te.getDur());
                    continue;
                }
                last = te;
                if (te.getArgs() == null || te.getArgs().isEmpty()) {
                    out.add(te);
                    continue;
                }
                String n = (String)te.getArgs().get("name");
                if (n.matches("[\\w/_-]+:[\\w/_-]+#id=\\d+.*")) {
                    int idx = n.indexOf("#");
                    String sub1 = n.substring(0, idx);
                    String sub = sub1.contains(":") ? sub1.substring(0, sub1.lastIndexOf(":")) : sub1;
                    if (map.containsKey(sub)) {
                        TraceEvent t = (TraceEvent)map.get(sub);
                        Long dur = t.getDur();
                        if (dur == null && te.getDur() == null) continue;
                        t.setDur(dur == null ? te.getDur() : dur + (te.getDur() == null ? 0L : te.getDur()));
                        continue;
                    }
                    map.put(sub, te);
                    out.add(te);
                    continue;
                }
                if (map.containsKey(n)) {
                    TraceEvent t = (TraceEvent)map.get(n);
                    t.setDur(t.getDur() + te.getDur());
                    continue;
                }
                map.put(n, te);
                out.add(te);
            }
            boolean bl = false;
            while (var10_24 < out.size()) {
                String n;
                TraceEvent te = (TraceEvent)out.get((int)var10_24);
                if (te.getArgs() != null && !te.getArgs().isEmpty() && (n = (String)te.getArgs().get("name")).matches("[\\w/_-]+:[\\w/_-]+#id=\\d+.*")) {
                    int idx = n.indexOf(58);
                    String sub = n.substring(0, idx);
                    te.getArgs().put("name", sub);
                }
                ++var10_24;
            }
            events = out.toArray(new TraceEvent[0]);
        }
        return events;
    }

    public static String summarizeTraceEvents(TraceEvent[] events) {
        Pair<Long, Map<String, OpStats>> p = ProfileAnalyzer.aggregateTraceEvents(events);
        final Map<String, OpStats> stats = p.getSecond();
        long allOpsUs = p.getFirst();
        ArrayList<String> l = new ArrayList<String>(stats.keySet());
        Collections.sort(l, new Comparator<String>(){

            @Override
            public int compare(String o1, String o2) {
                return -Long.compare(((OpStats)stats.get(o1)).getSumUs(), ((OpStats)stats.get(o2)).getSumUs());
            }
        });
        int longestName = 30;
        int longestOpName = 30;
        for (String s : l) {
            longestName = Math.max(longestName, s.length() + 1);
            longestOpName = Math.max(longestOpName, stats.get(s).getOpName().length() + 1);
        }
        StringBuilder sb = new StringBuilder();
        String headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n";
        sb.append(String.format(headerFormat, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std"));
        String format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
        for (String s : l) {
            OpStats st = stats.get(s);
            double pc = 100.0 * (double)st.getSumUs().longValue() / (double)allOpsUs;
            INDArray arr = st.getTimesUs().array();
            long min = arr.minNumber().longValue();
            long max = arr.maxNumber().longValue();
            double std = arr.stdNumber().doubleValue();
            sb.append(String.format(format, s, st.getOpName(), st.getCount(), st.getSumUs(), pc, min, max, std));
        }
        return sb.toString();
    }

    private static Pair<Long, Map<String, OpStats>> aggregateTraceEvents(TraceEvent[] events) {
        HashMap<String, OpStats> stats = new HashMap<String, OpStats>();
        for (TraceEvent e : events) {
            OpStats s;
            if (e.getPh() != Phase.X || e.getDur() == null) continue;
            String instanceName = (String)e.getArgs().get("name");
            if (stats.containsKey(instanceName)) {
                s = (OpStats)stats.get(instanceName);
            } else {
                s = new OpStats(instanceName, e.getName(), 0, new NDArrayList(DataType.LONG, 0), null);
                stats.put(instanceName, s);
            }
            s.setCount(s.getCount() + 1);
            s.getTimesUs().add((double)e.getDur());
        }
        long allOpsUs = 0L;
        for (OpStats s : stats.values()) {
            s.setSumUs(s.getTimesUs().array().sumNumber().longValue());
            allOpsUs += s.getSumUs().longValue();
        }
        return new Pair<Long, Map<String, OpStats>>(allOpsUs, stats);
    }

    public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2) {
        if (file1 == null) {
            throw new NullPointerException("file1 is marked non-null but is null");
        }
        if (file2 == null) {
            throw new NullPointerException("file2 is marked non-null but is null");
        }
        if (format1 == null) {
            throw new NullPointerException("format1 is marked non-null but is null");
        }
        if (format2 == null) {
            throw new NullPointerException("format2 is marked non-null but is null");
        }
        return ProfileAnalyzer.compareProfiles(file1, file2, format1, format2, false, false, null, null, SortBy.PROFILE1_PC);
    }

    public static String compareProfiles(@NonNull File file1, @NonNull File file2, @NonNull ProfileFormat format1, @NonNull ProfileFormat format2, boolean firstIsDir, boolean secondIsDir, String name1, String name2, SortBy sortBy) {
        if (file1 == null) {
            throw new NullPointerException("file1 is marked non-null but is null");
        }
        if (file2 == null) {
            throw new NullPointerException("file2 is marked non-null but is null");
        }
        if (format1 == null) {
            throw new NullPointerException("format1 is marked non-null but is null");
        }
        if (format2 == null) {
            throw new NullPointerException("format2 is marked non-null but is null");
        }
        return ProfileAnalyzer.compareProfiles(Config.builder().profile1(file1).profile2(file2).profile1Format(format1).profile2Format(format2).profile1IsDir(firstIsDir).profile2IsDir(secondIsDir).p1Name(name1).p2Name(name2).sortBy(sortBy).build());
    }

    public static String compareProfiles(final Config c) {
        String format;
        String headerFormat;
        TraceEvent[] t1 = c.profile1IsDir() ? ProfileAnalyzer.getTraceEventsDir(c.profile1(), c.profile1Format()) : ProfileAnalyzer.getTraceEvents(c.profile1(), c.profile1Format());
        TraceEvent[] t2 = c.profile2IsDir() ? ProfileAnalyzer.getTraceEventsDir(c.profile2(), c.profile2Format()) : ProfileAnalyzer.getTraceEvents(c.profile2(), c.profile2Format());
        final Pair<Long, Map<String, OpStats>> p1 = ProfileAnalyzer.aggregateTraceEvents(t1);
        final Pair<Long, Map<String, OpStats>> p2 = ProfileAnalyzer.aggregateTraceEvents(t2);
        ArrayList<String> l = new ArrayList<String>(c.sortBy() != SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet());
        Collections.sort(l, new Comparator<String>(){

            @Override
            public int compare(String o1, String o2) {
                switch (c.sortBy()) {
                    case PROFILE1_PC: {
                        return -Long.compare(((OpStats)((Map)p1.getSecond()).get(o1)).getSumUs(), ((OpStats)((Map)p1.getSecond()).get(o2)).getSumUs());
                    }
                    case PROFILE2_PC: {
                        return -Long.compare(((OpStats)((Map)p2.getSecond()).get(o1)).getSumUs(), ((OpStats)((Map)p2.getSecond()).get(o2)).getSumUs());
                    }
                    case RATIO: {
                        double m1a = ProfileAnalyzer.meanTime(p1, o1);
                        double m1b = ProfileAnalyzer.meanTime(p1, o2);
                        double m2a = ProfileAnalyzer.meanTime(p2, o1);
                        double m2b = ProfileAnalyzer.meanTime(p2, o2);
                        double ratio1 = m1a / m2a;
                        double ratio2 = m1b / m2b;
                        return -Double.compare(ratio1, ratio2);
                    }
                }
                throw new RuntimeException();
            }
        });
        HashSet<String> set = new HashSet<String>(l);
        StringBuilder sb = new StringBuilder();
        sb.append("1 = ").append(c.p1Name() == null ? "Profile 1" : c.p1Name()).append("\n").append("2 = ").append(c.p2Name() == null ? "Profile 2" : c.p2Name()).append("\n");
        int longestName = 30;
        int longestOpName = 30;
        Map<String, OpStats> stats = c.sortBy() == SortBy.PROFILE2_PC ? p2.getSecond() : p1.getSecond();
        for (String s : l) {
            longestName = Math.max(longestName, s.length() + 1);
            longestOpName = Math.max(longestOpName, stats.get(s).getOpName().length() + 1);
        }
        if (c.format() == null || c.format() == OutputFormat.TEXT) {
            headerFormat = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-16s%-13s%-13s%-14s%-14s%-12s%-12s%-14s%-14s%-10s%-10s%-10s%-10s\n";
            format = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-16.2f%-13.2f%-13.2f%-14d%-14d%-12.2f%-12.2f%-14d%-14d%-10d%-10d%-10.2f%-10.2f\n";
        } else {
            headerFormat = "%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s\n";
            format = "%s,%s,%d,%d,%.2f,%.2f,%.2f,%d,%d,%.2f,%.2f,%d,%d,%d,%d,%.2f,%.2f\n";
        }
        sb.append(String.format(headerFormat, "Op Name", "Op", "Count (1)", "Count (2)", "Mean Ratio 1/2", "Mean (1)", "Mean (2)", "Total uS (1)", "Total uS (2)", "% (1)", "% (2)", "Min (1)", "Min (2)", "Max (1)", "Max (2)", "Std (1)", "Std (2)"));
        for (String s : l) {
            OpStats s1 = p1.getSecond().get(s);
            OpStats s2 = p2.getSecond().get(s);
            if (c.filter() != null && !c.filter().apply(s1, s2).booleanValue()) continue;
            double m1 = s1 == null ? 0.0 : s1.getTimesUs().array().meanNumber().doubleValue();
            double m2 = s2 == null ? 0.0 : s2.getTimesUs().array().meanNumber().doubleValue();
            double ratio = m1 / m2;
            double pc1 = s1 == null ? 0.0 : 100.0 * (double)s1.getSumUs().longValue() / (double)p1.getFirst().longValue();
            double pc2 = s2 == null ? 0.0 : 100.0 * (double)s2.getSumUs().longValue() / (double)p2.getFirst().longValue();
            sb.append(String.format(format, s, s1 != null ? s1.getOpName() : s2.getOpName(), s1 != null ? s1.getCount() : 0, s2 != null ? s2.getCount() : 0, ratio, m1, m2, s1 != null ? s1.getSumUs() : 0L, s2 != null ? s2.getSumUs() : 0L, pc1, pc2, s1 != null ? s1.getTimesUs().array().minNumber().longValue() : 0L, s2 != null ? s2.getTimesUs().array().minNumber().longValue() : 0L, s1 != null ? s1.getTimesUs().array().maxNumber().longValue() : 0L, s2 != null ? s2.getTimesUs().array().maxNumber().longValue() : 0L, s1 != null ? s1.getTimesUs().array().stdNumber().doubleValue() : 0.0, s2 != null ? s2.getTimesUs().array().stdNumber().doubleValue() : 0.0));
        }
        boolean header = false;
        String headerFormat2 = null;
        String format3 = null;
        ArrayList<String> toAppend = null;
        for (String s : c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond().keySet() : p2.getSecond().keySet()) {
            if (set.contains(s)) continue;
            Map<String, OpStats> m = c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond() : p2.getSecond();
            OpStats st = m.get(s);
            if (c.filter() != null) {
                OpStats other;
                OpStats opStats = other = c.sortBy() == SortBy.PROFILE2_PC ? p1.getSecond().get(s) : p2.getSecond().get(s);
                boolean keep = c.filter().apply(other, st);
                if (!keep) continue;
            }
            if (!header) {
                toAppend = new ArrayList<String>();
                longestName = 30;
                longestOpName = 30;
                for (String s2 : m.keySet()) {
                    longestName = Math.max(longestName, s2.length() + 1);
                    longestOpName = Math.max(longestOpName, m.get(s2).getOpName().length() + 1);
                }
                if (c.format() == null || c.format() == OutputFormat.TEXT) {
                    headerFormat2 = "%-" + longestName + "s%-" + longestOpName + "s%-10s%-10s%-10s%-10s%-10s%-10s\n";
                    format3 = "%-" + longestName + "s%-" + longestOpName + "s%-10d%-10d%-10.2f%-10d%-10d%-10.2f\n";
                } else {
                    headerFormat2 = "%s,%s,%s,%s,%s,%s,%s,%s\n";
                    format3 = "%s,%s,%d,%d,%.2f,%d,%d,%.2f\n";
                }
                sb.append(" *** Operations not in profile ").append(c.sortBy() == SortBy.PROFILE2_PC ? "1" : "2").append(" but in profile ").append(c.sortBy() == SortBy.PROFILE2_PC ? "2" : "1").append(" ***\n");
                sb.append(String.format(headerFormat2, "Op Name", "Op", "Count", "Total uS", "%", "Min", "Max", "Std"));
                header = true;
            }
            long allOpsUs = c.sortBy() == SortBy.PROFILE2_PC ? p1.getFirst().longValue() : p2.getFirst().longValue();
            double pc = 100.0 * (double)st.getTimesUs().array().sumNumber().longValue() / (double)allOpsUs;
            INDArray arr = st.getTimesUs().array();
            long min = arr.minNumber().longValue();
            long max = arr.maxNumber().longValue();
            double std = arr.stdNumber().doubleValue();
            toAppend.add(String.format(format3, s, st.getOpName(), st.getCount(), st.getSumUs(), pc, min, max, std));
        }
        if (toAppend != null) {
            Collections.sort(toAppend);
            for (String s : toAppend) {
                sb.append(s);
            }
        }
        return sb.toString();
    }

    private static double meanTime(Pair<Long, Map<String, OpStats>> p, String name) {
        if (!p.getSecond().containsKey(name)) {
            return 0.0;
        }
        return p.getSecond().get(name).getTimesUs().array().meanNumber().doubleValue();
    }

    static {
        TF_PROFILE_ALIASES.put("_MklSoftmax", "Softmax");
    }

    public static enum OutputFormat {
        TEXT,
        CSV;

    }

    public static enum SortBy {
        PROFILE1_PC,
        PROFILE2_PC,
        RATIO;

    }

    public static enum ProfileFormat {
        SAMEDIFF,
        TENSORFLOW;

    }
}

