/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark.functions;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.sysds.runtime.data.TensorBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import scala.Tuple2;

public class ReplicateTensorFunction
implements PairFlatMapFunction<Tuple2<TensorIndexes, TensorBlock>, TensorIndexes, TensorBlock> {
    private static final long serialVersionUID = 7181347334827684965L;
    private int _byDim;
    private long _numReplicas;

    public ReplicateTensorFunction(int byDim, long numReplicas) {
        this._byDim = byDim;
        this._numReplicas = numReplicas;
    }

    public Iterator<Tuple2<TensorIndexes, TensorBlock>> call(Tuple2<TensorIndexes, TensorBlock> arg0) throws Exception {
        TensorIndexes ix = (TensorIndexes)arg0._1();
        TensorBlock tb = (TensorBlock)arg0._2();
        if (ix.getIndex(this._byDim) != 1L || tb.getNumDims() > this._byDim && tb.getDim(this._byDim) > 1) {
            throw new Exception("Expected dimension " + this._byDim + " to be 1 in ReplicateTensor");
        }
        ArrayList<Tuple2> retVal = new ArrayList<Tuple2>();
        long[] indexes = ix.getIndexes();
        int i = 1;
        while ((long)i <= this._numReplicas) {
            indexes[this._byDim] = i;
            retVal.add(new Tuple2((Object)new TensorIndexes(indexes), (Object)tb));
            ++i;
        }
        return retVal.iterator();
    }
}

