package org.nd4j.linalg.util;

import edu.umd.cs.findbugs.annotations.Nullable;
import java.util.Arrays;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link DeviceLocal} holder for {@link INDArray}s that keeps one array per device reported by the
 * affinity manager.
 *
 * <p>In immediate mode ({@code delayedMode == false}) {@link #broadcast(INDArray)} replicates the
 * array to every device right away. In delayed mode only the calling device receives the array; a
 * detached master copy is kept in {@code delayedArray} and every other device is marked stale in
 * {@code updatesMap}, so each of them materializes its own copy on its first {@link #get()}.
 *
 * <p>As used in this class, {@code updatesMap.get(d)} holds the id of the device that owns the most
 * recent data for device {@code d}; a non-negative value different from {@code d} means "stale".
 *
 * <p>{@link #get()}, {@link #broadcast(INDArray)} and {@link #update(INDArray)} are
 * {@code synchronized} on this instance.
 */
public class DeviceLocalNDArray extends DeviceLocal<INDArray> {
    private static final Logger log = LoggerFactory.getLogger(DeviceLocalNDArray.class);

    /** Creates an empty holder in immediate (non-delayed) mode. */
    public DeviceLocalNDArray() {
        this(false);
    }

    /**
     * Creates an empty holder.
     *
     * @param delayedMode if {@code true}, per-device copies are created lazily on first access
     */
    public DeviceLocalNDArray(boolean delayedMode) {
        super(delayedMode);
    }

    /**
     * Creates a holder in immediate mode and broadcasts {@code array} to all devices.
     *
     * @param array initial array; {@code null} leaves the holder empty
     */
    public DeviceLocalNDArray(INDArray array) {
        this(array, false);
    }

    /**
     * Creates a holder and broadcasts {@code array} to all devices.
     *
     * @param array initial array; {@code null} leaves the holder empty
     * @param delayedMode if {@code true}, per-device copies are created lazily on first access
     */
    public DeviceLocalNDArray(INDArray array, boolean delayedMode) {
        super(delayedMode);
        broadcast(array);
    }

    /**
     * Returns the array that belongs to the calling thread's device.
     *
     * <p>If this device is marked stale (delayed mode, or a slot that was empty during
     * {@link #update(INDArray)}), a fresh copy of {@code delayedArray} is allocated for it first.
     * Once no device is stale any more, the master copy is released.
     *
     * @return the array for the current device, or {@code null} if none has been set
     */
    @Override
    @Nullable
    public synchronized INDArray get() {
        int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        if (isStale(deviceId)) {
            INDArray fresh = Nd4j.create(delayedArray.dataType(), delayedArray.shape(),
                    delayedArray.stride(), delayedArray.ordering());
            Nd4j.getMemoryManager().memcpy(fresh.data(), delayedArray.data());
            // go through set() so the per-device write lock is honoured
            set(deviceId, fresh);
            updatesMap.get(deviceId).set(deviceId);

            // every device now has its own copy: the master copy is no longer needed
            if (allDevicesUpToDate()) {
                delayedArray = null;
            }
        }
        return get(deviceId);
    }

    /**
     * Makes {@code array} the current value on every device.
     *
     * <p>Immediate mode: the calling device stores {@code array.detach()} and every other device
     * receives {@code replicateToDevice(...)}. Delayed mode: the calling device stores
     * {@code array.detach()}, a detached duplicate is kept as the master copy, and all other devices
     * are marked stale.
     *
     * <p>Profiler locality checks are disabled for the duration of the call and restored afterwards,
     * even if an exception is thrown.
     *
     * @param array array to distribute; {@code null} is ignored
     * @throws IllegalArgumentException if {@code array} is a view with element-wise stride 1
     */
    public synchronized void broadcast(INDArray array) {
        if (array == null) {
            return;
        }
        Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1,
                "View can't be used in DeviceLocalNDArray");
        Nd4j.getExecutioner().commit();

        ProfilerConfig config = OpProfiler.getInstance().getConfig();
        boolean checkLocality = config.isCheckLocality();
        if (checkLocality) {
            config.setCheckLocality(false);
        }
        try {
            int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
            int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            if (delayedMode) {
                // detach, like the immediate branch, so a workspace-attached array is never retained
                set(deviceId, array.detach());
                delayedArray = array.dup(array.ordering()).detach();
                // the calling device now owns the latest data (marks itself fresh, everyone else stale)
                for (int d = 0; d < numDevices; d++) {
                    updatesMap.get(d).set(deviceId);
                }
            } else {
                for (int d = 0; d < numDevices; d++) {
                    if (d == deviceId) {
                        set(d, array.detach());
                    } else {
                        set(d, Nd4j.getAffinityManager().replicateToDevice(Integer.valueOf(d), array));
                    }
                }
            }
        } finally {
            config.setCheckLocality(checkLocality);
        }
    }

    /**
     * Overwrites the contents of every device's array with the data of {@code array}.
     *
     * <p>When the calling device has a usable array with identical shape info, the data is copied
     * in place into each device's existing array. Devices that have no array, or whose array is
     * stale, receive a refreshed master copy that is materialized on their next {@link #get()}.
     * In every other case (no array on this device, this device is stale, or shape info differs)
     * the call falls back to {@link #broadcast(INDArray)}.
     *
     * @param array new contents; must not be {@code null}
     * @throws NullPointerException if {@code array} is {@code null}
     * @throws IllegalArgumentException if {@code array} is a view with element-wise stride 1
     */
    public synchronized void update(@NonNull INDArray array) {
        if (array == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1,
                "View can't be used in DeviceLocalNDArray");

        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        INDArray current = (INDArray) backingMap.get(deviceId);
        if (current == null || isStale(deviceId)
                || !Arrays.equals(current.shapeInfoJava(), array.shapeInfoJava())) {
            // nothing valid to copy into on this device: start over with a full broadcast
            broadcast(array);
            return;
        }

        boolean masterRefreshed = false;
        for (int d = 0; d < numDevices; d++) {
            ReentrantReadWriteLock lock = locksMap.get(d);
            lock.writeLock().lock();
            try {
                INDArray target = (INDArray) backingMap.get(d);
                if (target == null || isStale(d)) {
                    // copying into a stale replica is pointless (get() would replace it with the old
                    // master copy) and may be the wrong size, so defer it via the master copy instead
                    if (!masterRefreshed) {
                        delayedArray = array.dup(array.ordering()).detach();
                        masterRefreshed = true;
                    }
                    updatesMap.get(d).set(deviceId);
                } else {
                    Nd4j.getMemoryManager().memcpy(target.data(), array.data());
                    Nd4j.getExecutioner().commit();
                }
            } finally {
                lock.writeLock().unlock();
            }
        }
    }

    /** Returns {@code true} if {@code deviceId} has a pending update sourced from another device. */
    private boolean isStale(int deviceId) {
        int source = updatesMap.get(deviceId).get();
        return source >= 0 && source != deviceId;
    }

    /** Returns {@code true} if every device's marker in {@code updatesMap} points at itself. */
    private boolean allDevicesUpToDate() {
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int d = 0; d < numDevices; d++) {
            if (updatesMap.get(d).get() != d) {
                return false;
            }
        }
        return true;
    }
}
