/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.common.primitives;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.AtomicDouble;
import org.nd4j.common.primitives.Pair;

/**
 * A map from elements to {@code double} counts that also tracks the sum of all counts.
 *
 * <p>Storage is a {@link ConcurrentHashMap} of {@link AtomicDouble}s. Single-element updates
 * ({@link #incrementCount(Object, double)}, {@link #setCount(Object, double)}) insert new keys
 * with {@code putIfAbsent}, so concurrent callers do not overwrite each other's entries. Bulk
 * operations ({@link #normalize()}, {@link #keepTopNElements(int)}, the lazy total rebuild)
 * are not atomic with respect to concurrent writers.
 *
 * <p>The running total is updated incrementally when a count is inserted or incremented.
 * Operations that overwrite or remove an existing count instead set a dirty flag, and the
 * total is recomputed on the next call to {@link #totalCount()}.
 *
 * @param <T> the element type; keys are stored in a {@link ConcurrentHashMap}, which rejects
 *     {@code null}
 */
public class Counter<T> implements Serializable {
    private static final long serialVersionUID = 119L;
    protected ConcurrentHashMap<T, AtomicDouble> map = new ConcurrentHashMap<>();
    protected AtomicDouble totalCount = new AtomicDouble(0.0);
    protected AtomicBoolean dirty = new AtomicBoolean(false);

    /**
     * Returns the count stored for {@code element}.
     *
     * @param element the element to look up
     * @return the element's count, or {@code 0.0} if it is not present
     */
    public double getCount(T element) {
        AtomicDouble count = this.map.get(element);
        return count == null ? 0.0 : count.get();
    }

    /**
     * Adds {@code inc} to the count of {@code element}, inserting it if absent, and adds
     * {@code inc} to the running total.
     *
     * @param element the element to increment
     * @param inc the amount to add (may be negative)
     */
    public void incrementCount(T element, double inc) {
        AtomicDouble current = this.map.get(element);
        if (current == null) {
            // putIfAbsent closes the get-then-put race: if another thread inserted the element
            // in the meantime, we get its holder back and add to it instead of replacing it.
            current = this.map.putIfAbsent(element, new AtomicDouble(inc));
        }
        if (current != null) {
            current.addAndGet(inc);
        }
        this.totalCount.addAndGet(inc);
    }

    /**
     * Increments every element of {@code elements} by {@code inc}. Duplicates are incremented
     * once per occurrence.
     *
     * @param elements the elements to increment
     * @param inc the amount to add to each
     */
    public void incrementAll(Collection<T> elements, double inc) {
        for (T element : elements) {
            this.incrementCount(element, inc);
        }
    }

    /**
     * Adds every count from {@code other} into this counter.
     *
     * @param other the counter whose counts are added
     * @param <T2> element type of {@code other}, a subtype of {@code T}
     */
    public <T2 extends T> void incrementAll(Counter<T2> other) {
        for (T2 element : other.keySet()) {
            this.incrementCount(element, other.getCount(element));
        }
    }

    /**
     * Returns {@code getCount(element) / totalCount()}.
     *
     * @param element the element whose probability is requested
     * @return the element's share of the total count
     * @throws IllegalStateException if the total count is not positive
     */
    public double getProbability(T element) {
        double total = this.totalCount();
        if (total <= 0.0) {
            throw new IllegalStateException("Can't calculate probability with empty counter");
        }
        return this.getCount(element) / total;
    }

    /**
     * Sets the count of {@code element} to {@code count}.
     *
     * <p>For a new element the running total is increased by {@code count}. For an existing
     * element the total is marked dirty and recomputed lazily.
     *
     * @param element the element to set
     * @param count the new count
     * @return the previous count, or {@code 0.0} if the element was absent
     */
    public double setCount(T element, double count) {
        AtomicDouble current = this.map.get(element);
        if (current == null) {
            current = this.map.putIfAbsent(element, new AtomicDouble(count));
            if (current == null) {
                // We inserted the element ourselves, so the total can be updated incrementally.
                this.totalCount.addAndGet(count);
                return 0.0;
            }
        }
        double previous = current.getAndSet(count);
        this.dirty.set(true);
        return previous;
    }

    /**
     * Returns a live view of the elements in this counter (backed by the underlying map).
     *
     * @return the key set view
     */
    public Set<T> keySet() {
        return this.map.keySet();
    }

    /**
     * Returns whether this counter has no elements.
     *
     * @return {@code true} if the counter has no elements
     */
    public boolean isEmpty() {
        return this.map.isEmpty();
    }

    /**
     * Returns a live view of the element/count entries (backed by the underlying map).
     *
     * @return the entry set view
     */
    public Set<Map.Entry<T, AtomicDouble>> entrySet() {
        return this.map.entrySet();
    }

    /**
     * Returns the elements ordered from highest to lowest count.
     *
     * @return a new list of elements sorted by descending count
     */
    public List<T> keySetSorted() {
        List<T> result = new ArrayList<>(this.map.size());
        PriorityQueue<Pair<T, Double>> pq = this.asPriorityQueue();
        while (!pq.isEmpty()) {
            result.add(pq.poll().getFirst());
        }
        return result;
    }

    /**
     * Divides every count by the current total, then recomputes the total.
     *
     * <p>If the total is zero, counts become NaN or infinite (floating-point division by zero).
     */
    public void normalize() {
        // totalCount() rebuilds the sum if a prior overwrite/removal left it stale; reading the
        // raw field here would normalize by an outdated total.
        double total = this.totalCount();
        for (AtomicDouble count : this.map.values()) {
            count.set(count.get() / total);
        }
        this.rebuildTotals();
    }

    /** Recomputes the running total from the stored counts and clears the dirty flag. */
    protected void rebuildTotals() {
        // Sum locally and publish once, so readers never observe a partially rebuilt total.
        double sum = 0.0;
        for (AtomicDouble count : this.map.values()) {
            sum += count.get();
        }
        this.totalCount.set(sum);
        this.dirty.set(false);
    }

    /**
     * Returns the sum of all counts, rebuilding it first if it was marked dirty.
     *
     * @return the total count
     */
    public double totalCount() {
        if (this.dirty.get()) {
            this.rebuildTotals();
        }
        return this.totalCount.get();
    }

    /**
     * Removes {@code element} from the counter and marks the total dirty if it was present.
     *
     * @param element the element to remove
     * @return the removed count, or {@code 0.0} if the element was absent
     */
    public double removeKey(T element) {
        AtomicDouble removed = this.map.remove(element);
        if (removed == null) {
            return 0.0;
        }
        this.dirty.set(true);
        return removed.get();
    }

    /**
     * Returns an element with the highest count. Ties resolve to whichever is encountered first
     * in map iteration order.
     *
     * @return the element with the highest count, or {@code null} if the counter is empty
     */
    public T argMax() {
        T maxKey = null;
        double maxCount = -Double.MAX_VALUE;
        for (Map.Entry<T, AtomicDouble> entry : this.map.entrySet()) {
            // Read the value once so the comparison and the recorded max agree.
            double value = entry.getValue().get();
            if (maxKey == null || value > maxCount) {
                maxKey = entry.getKey();
                maxCount = value;
            }
        }
        return maxKey;
    }

    /**
     * Removes every element whose count is strictly below {@code threshold}. NaN counts are kept.
     *
     * @param threshold the minimum count to keep
     */
    public void dropElementsBelowThreshold(double threshold) {
        // Test the value held by each entry instead of re-looking it up by key, which could
        // return null if the key was removed concurrently.
        boolean removedAny = this.map.entrySet().removeIf(entry -> entry.getValue().get() < threshold);
        if (removedAny) {
            this.dirty.set(true);
        }
    }

    /**
     * Returns whether {@code element} has an entry in this counter.
     *
     * @param element the element to check
     * @return {@code true} if the element is present (even with a zero count)
     */
    public boolean containsElement(T element) {
        return this.map.containsKey(element);
    }

    /** Removes all elements and resets the total to zero. */
    public void clear() {
        this.map.clear();
        this.totalCount.set(0.0);
        this.dirty.set(false);
    }

    /**
     * Two counters are equal when their backing maps are equal. Value comparison is delegated
     * to {@code AtomicDouble.equals}. NOTE(review): confirm that AtomicDouble compares by value;
     * otherwise only counters sharing the same holder instances compare equal.
     */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof Counter)) {
            return false;
        }
        Counter<?> other = (Counter<?>) o;
        return this.map.equals(other.map);
    }

    @Override
    public int hashCode() {
        return this.map.hashCode();
    }

    /**
     * Returns the number of distinct elements.
     *
     * @return the number of elements
     */
    public int size() {
        return this.map.size();
    }

    /**
     * Keeps only the {@code N} elements with the highest counts and discards the rest.
     *
     * @param N the number of elements to keep; values {@code <= 0} empty the counter
     */
    public void keepTopNElements(int N) {
        PriorityQueue<Pair<T, Double>> queue = this.asPriorityQueue();
        this.clear();
        for (int kept = 0; kept < N && !queue.isEmpty(); ++kept) {
            Pair<T, Double> pair = queue.poll();
            this.incrementCount(pair.getFirst(), pair.getSecond());
        }
    }

    /**
     * Returns a snapshot queue that polls elements from highest to lowest count.
     *
     * @return a new priority queue of (element, count) pairs
     */
    public PriorityQueue<Pair<T, Double>> asPriorityQueue() {
        return this.snapshotQueue(new PairComparator());
    }

    /**
     * Returns a snapshot queue that polls elements from lowest to highest count.
     *
     * @return a new priority queue of (element, count) pairs
     */
    public PriorityQueue<Pair<T, Double>> asReversedPriorityQueue() {
        return this.snapshotQueue(new ReversedPairComparator());
    }

    /** Copies the current (element, count) entries into a priority queue with {@code order}. */
    private PriorityQueue<Pair<T, Double>> snapshotQueue(Comparator<Pair<T, Double>> order) {
        PriorityQueue<Pair<T, Double>> pq = new PriorityQueue<>(Math.max(1, this.map.size()), order);
        for (Map.Entry<T, AtomicDouble> entry : this.map.entrySet()) {
            pq.add(Pair.create(entry.getKey(), entry.getValue().get()));
        }
        return pq;
    }

    /**
     * Orders (element, count) pairs by ascending count.
     *
     * <p>This is a non-static inner class, kept that way for source compatibility with callers
     * that instantiate it.
     */
    public class ReversedPairComparator implements Comparator<Pair<T, Double>> {
        @Override
        public int compare(Pair<T, Double> o1, Pair<T, Double> o2) {
            return Double.compare(o1.getSecond(), o2.getSecond());
        }
    }

    /**
     * Orders (element, count) pairs by descending count.
     *
     * <p>This is a non-static inner class, kept that way for source compatibility with callers
     * that instantiate it.
     */
    public class PairComparator implements Comparator<Pair<T, Double>> {
        @Override
        public int compare(Pair<T, Double> o1, Pair<T, Double> o2) {
            return Double.compare(o2.getSecond(), o1.getSecond());
        }
    }
}

