/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.transform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.transform.OpPredicate;
import org.nd4j.autodiff.samediff.transform.SubGraph;
import org.nd4j.common.base.Preconditions;

/**
 * An {@link OpPredicate} that matches a small subgraph rooted at a single op.
 * <p>
 * A function matches when the root predicate matches it and every optional constraint configured through the
 * {@code with*} methods also holds:
 * <ul>
 *     <li>the number of input variables of the root op;</li>
 *     <li>the number of output variables of the root op;</li>
 *     <li>for selected input positions, a predicate on the op that produces that input.</li>
 * </ul>
 * Configuration is fluent: each {@code with*} instance method mutates this predicate and returns {@code this}.
 */
public class SubGraphPredicate extends OpPredicate {
    /** Predicate that the root op itself must satisfy. */
    protected final OpPredicate root;
    /** Required number of inputs of the root op, or {@code null} for no constraint. */
    protected Integer inputCount = null;
    /** Required number of outputs of the root op, or {@code null} for no constraint. */
    protected Integer outputCount = null;
    /**
     * Input index to the predicate that the op producing that input must match. These producer ops are only
     * checked; they are not added to the subgraph returned by {@link #getSubGraph}.
     */
    protected Map<Integer, OpPredicate> opInputMatchPredicates = new HashMap<Integer, OpPredicate>();
    /**
     * Input index to the predicate that the op producing that input must match. These producer ops are added to
     * the subgraph returned by {@link #getSubGraph}. When the predicate is itself a SubGraphPredicate, the nested
     * subgraph's child nodes are added as well.
     */
    protected Map<Integer, OpPredicate> opInputSubgraphPredicates = new HashMap<Integer, OpPredicate>();

    /**
     * @param root predicate the root op must satisfy; callers normally use {@link #withRoot(OpPredicate)}
     */
    protected SubGraphPredicate(OpPredicate root) {
        this.root = root;
    }

    /**
     * Returns whether {@code rootFn} is the root of a subgraph matching this predicate.
     *
     * @param sameDiff graph containing {@code rootFn}, used to look up the ops that produce its inputs
     * @param rootFn   candidate root op
     * @return {@code true} only if the root predicate and every configured input-count, output-count and
     *         input-op constraint are satisfied. A configured input index outside the root's inputs
     *         (including a negative one) never matches.
     */
    @Override
    public boolean matches(SameDiff sameDiff, DifferentialFunction rootFn) {
        if (!this.root.matches(sameDiff, rootFn)) {
            return false;
        }

        SDVariable[] inputs = rootFn.args();
        // A null args array is treated as "no inputs"
        int inCount = inputs == null ? 0 : inputs.length;
        if (this.inputCount != null && inCount != this.inputCount) {
            return false;
        }

        SDVariable[] outputs = rootFn.outputVariables();
        int outCount = outputs == null ? 0 : outputs.length;
        if (this.outputCount != null && outCount != this.outputCount) {
            return false;
        }

        // Same order as before: plain match predicates first, then subgraph predicates
        return inputOpsMatch(sameDiff, inputs, inCount, this.opInputMatchPredicates)
                && inputOpsMatch(sameDiff, inputs, inCount, this.opInputSubgraphPredicates);
    }

    /**
     * Checks that, for every (input index, predicate) entry, the op producing that input of the root exists and
     * satisfies the predicate.
     *
     * @param sameDiff   graph used to look up producer ops
     * @param inputs     root op inputs; may be null only when {@code inCount} is 0
     * @param inCount    number of root op inputs
     * @param predicates input index to producer-op predicate
     * @return {@code false} at the first entry whose index is out of range, whose producer op is null, or whose
     *         predicate does not match; {@code true} otherwise
     */
    private static boolean inputOpsMatch(SameDiff sameDiff, SDVariable[] inputs, int inCount,
                                         Map<Integer, OpPredicate> predicates) {
        for (Map.Entry<Integer, OpPredicate> e : predicates.entrySet()) {
            int inNum = e.getKey();
            // Reject negative indices explicitly; previously they threw ArrayIndexOutOfBoundsException
            // (or NPE when args() was null) instead of simply not matching
            if (inNum < 0 || inNum >= inCount) {
                return false;
            }
            DifferentialFunction producer = sameDiff.getVariableOutputOp(inputs[inNum].name());
            if (producer == null || !e.getValue().matches(sameDiff, producer)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds the subgraph rooted at {@code rootFn}.
     * <p>
     * The child nodes are the ops producing the inputs registered via
     * {@link #withInputSubgraph(int, OpPredicate)}. When such an input's predicate is itself a
     * {@code SubGraphPredicate}, that nested subgraph's child nodes are appended recursively. Inputs registered
     * via {@link #withInputMatching(int, OpPredicate)} are checked by {@link #matches} but not included.
     *
     * @param sd     graph containing {@code rootFn}
     * @param rootFn root op; must satisfy {@link #matches(SameDiff, DifferentialFunction)}
     * @return the subgraph with {@code rootFn} as root node
     * Fails the {@code Preconditions.checkState} check when {@code rootFn} does not match this predicate.
     */
    public SubGraph getSubGraph(SameDiff sd, DifferentialFunction rootFn) {
        Preconditions.checkState(this.matches(sd, rootFn), "Root function does not match predicate");

        ArrayList<DifferentialFunction> childNodes = new ArrayList<DifferentialFunction>();
        for (Map.Entry<Integer, OpPredicate> entry : this.opInputSubgraphPredicates.entrySet()) {
            SDVariable arg = rootFn.arg(entry.getKey());
            DifferentialFunction producer = sd.getVariableOutputOp(arg.name());
            if (producer == null) {
                // matches() already required a non-null producer for these inputs (assuming arg(i) is
                // args()[i] - confirm); kept as a defensive guard
                continue;
            }
            childNodes.add(producer);

            OpPredicate p = entry.getValue();
            if (p instanceof SubGraphPredicate) {
                SubGraph nested = ((SubGraphPredicate) p).getSubGraph(sd, producer);
                childNodes.addAll(nested.childNodes);
            }
        }
        return SubGraph.builder().sameDiff(sd).rootNode(rootFn).childNodes(childNodes).build();
    }

    /**
     * Creates a predicate whose root op must satisfy {@code root}, with no further constraints yet.
     *
     * @param root predicate for the root op; must not be null
     * @return a new SubGraphPredicate
     * @throws NullPointerException if {@code root} is null
     */
    public static SubGraphPredicate withRoot(@NonNull OpPredicate root) {
        if (root == null) {
            throw new NullPointerException("root is marked non-null but is null");
        }
        return new SubGraphPredicate(root);
    }

    /**
     * Requires the root op to have exactly {@code inputCount} inputs (a null args array counts as zero).
     *
     * @param inputCount required number of inputs
     * @return this predicate, for chaining
     */
    public SubGraphPredicate withInputCount(int inputCount) {
        this.inputCount = inputCount;
        return this;
    }

    /**
     * Requires the root op to have exactly {@code outputCount} outputs (a null output array counts as zero).
     *
     * @param outputCount required number of outputs
     * @return this predicate, for chaining
     */
    public SubGraphPredicate withOutputCount(int outputCount) {
        this.outputCount = outputCount;
        return this;
    }

    /**
     * Requires the op producing input {@code inputNum} of the root to exist and match {@code opPredicate}.
     * That producer op is not included in {@link #getSubGraph}. Replaces any predicate previously set for
     * the same index via this method.
     *
     * @param inputNum    index into the root op's inputs
     * @param opPredicate predicate for the producer op; must not be null
     * @return this predicate, for chaining
     * @throws NullPointerException if {@code opPredicate} is null
     */
    public SubGraphPredicate withInputMatching(int inputNum, @NonNull OpPredicate opPredicate) {
        if (opPredicate == null) {
            throw new NullPointerException("opPredicate is marked non-null but is null");
        }
        this.opInputMatchPredicates.put(inputNum, opPredicate);
        return this;
    }

    /**
     * Requires the op producing input {@code inputNum} of the root to exist and match {@code opPredicate}.
     * That producer op is included in {@link #getSubGraph}, recursively when {@code opPredicate} is itself a
     * SubGraphPredicate. Replaces any predicate previously set for the same index via this method.
     *
     * @param inputNum    index into the root op's inputs
     * @param opPredicate predicate for the producer op; must not be null
     * @return this predicate, for chaining
     * @throws NullPointerException if {@code opPredicate} is null
     */
    public SubGraphPredicate withInputSubgraph(int inputNum, @NonNull OpPredicate opPredicate) {
        if (opPredicate == null) {
            throw new NullPointerException("opPredicate is marked non-null but is null");
        }
        this.opInputSubgraphPredicates.put(inputNum, opPredicate);
        return this;
    }
}

