/*
 * Decompiled with CFR 0.152.
 */
package net.snowflake.client.jdbc;

import java.io.IOException;
import java.io.InputStream;
import java.nio.channels.ClosedByInterruptException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import net.snowflake.client.core.DataConversionContext;
import net.snowflake.client.core.SFBaseSession;
import net.snowflake.client.core.SFException;
import net.snowflake.client.core.arrow.ArrowResultChunkIndexSorter;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import net.snowflake.client.core.arrow.ArrowVectorConverterUtil;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.SnowflakeResultChunk;
import net.snowflake.client.jdbc.SnowflakeSQLException;
import net.snowflake.client.jdbc.SnowflakeSQLLoggedException;
import net.snowflake.client.jdbc.internal.apache.arrow.memory.BufferAllocator;
import net.snowflake.client.jdbc.internal.apache.arrow.memory.RootAllocator;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.BigIntVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.BitVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.DateDayVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.DecimalVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.FieldVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.Float8Vector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.IntVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.SmallIntVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.TinyIntVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ValueVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.VarBinaryVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.VarCharVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.VectorSchemaRoot;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.complex.StructVector;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.ipc.ArrowStreamReader;
import net.snowflake.client.jdbc.internal.apache.arrow.vector.util.TransferPair;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

public class ArrowResultChunk
extends SnowflakeResultChunk {
    private static final SFLogger logger = SFLoggerFactory.getLogger(ArrowResultChunk.class);

    /**
     * Record batches held by this chunk: one inner list per batch that was loaded by
     * {@code readArrowStream}, each containing one vector per field of that batch.
     */
    private final List<List<ValueVector>> batchOfVectors = new ArrayList<>();

    /** Allocator handed to the Arrow stream reader and to every vector transfer. */
    private final RootAllocator rootAllocator;

    /** Set by {@code enableSortFirstResultChunk()}; consulted by the chunk iterator. */
    private boolean enableSortFirstResultChunk;

    /**
     * Sorted row indices produced by {@code sortFirstResultChunk}; {@code null} until a sort has
     * been performed. Closed (but not nulled) by {@code freeData()}.
     */
    private IntVector firstResultChunkSortedIndices;

    /** Schema root of the most recently read stream; cleared and nulled by {@code freeData()}. */
    private VectorSchemaRoot root;

    /** Session passed to converter initialization and to logged exceptions; may be null. */
    private final SFBaseSession session;

    /**
     * Creates a chunk description; no Arrow data is loaded until
     * {@link #readArrowStream(InputStream)} is called.
     *
     * @param url forwarded to the superclass constructor
     * @param rowCount forwarded to the superclass constructor
     * @param colCount forwarded to the superclass constructor
     * @param uncompressedSize forwarded to the superclass constructor
     * @param rootAllocator allocator used when reading the stream and transferring vectors
     * @param session session used for converter creation and error reporting
     */
    public ArrowResultChunk(String url, int rowCount, int colCount, int uncompressedSize, RootAllocator rootAllocator, SFBaseSession session) {
        super(url, rowCount, colCount, uncompressedSize);
        this.session = session;
        this.rootAllocator = rootAllocator;
    }

    /**
     * Appends one record batch (a list of column vectors) to this chunk.
     *
     * @param batch vectors of the batch to store
     */
    private void addBatchData(List<ValueVector> batch) {
        batchOfVectors.add(batch);
    }

    /**
     * Reads every record batch from an Arrow IPC stream and stores it in this chunk.
     *
     * <p>For each batch, every field vector of the reader's schema root is moved into a new
     * vector through a {@link TransferPair} created on this chunk's allocator; the resulting
     * vectors are stored as one batch and the root is then cleared.
     *
     * <p>If the read fails with {@link ClosedByInterruptException}, the problem is logged at
     * debug level, all vectors held by this chunk are released, and the method returns normally.
     * Any other exception also releases everything and is rethrown.
     *
     * @param is stream containing Arrow IPC data
     * @throws IOException if reading the stream fails
     */
    public void readArrowStream(InputStream is) throws IOException {
        // Tracks the batch currently being assembled so that a partially transferred batch
        // (not yet registered through addBatchData) can still be released on failure.
        List<ValueVector> valueVectors = new ArrayList<>();
        try (ArrowStreamReader reader = new ArrowStreamReader(is, this.rootAllocator)) {
            this.root = reader.getVectorSchemaRoot();
            while (reader.loadNextBatch()) {
                valueVectors = new ArrayList<>();
                for (FieldVector fieldVector : this.root.getFieldVectors()) {
                    TransferPair transferPair = fieldVector.getTransferPair(this.rootAllocator);
                    transferPair.transfer();
                    valueVectors.add(transferPair.getTo());
                }
                this.addBatchData(valueVectors);
                this.root.clear();
            }
        }
        catch (ClosedByInterruptException cbie) {
            // Deliberately not rethrown: the load is simply abandoned.
            logger.debug("Interrupted when loading Arrow result", cbie);
            valueVectors.forEach(ValueVector::close);
            this.freeData();
        }
        catch (Exception ex) {
            valueVectors.forEach(ValueVector::close);
            this.freeData();
            throw ex;
        }
    }

    /**
     * Releases all data via {@link #freeData()} and then empties the batch list.
     *
     * <p>The explicit clear is kept because {@code freeData()} can be overridden (the nested
     * empty chunk overrides it with a no-op).
     */
    @Override
    public void reset() {
        freeData();
        batchOfVectors.clear();
    }

    /**
     * Reports the memory needed for this chunk as its uncompressed size.
     *
     * @return the value of {@code getUncompressedSize()}, widened to {@code long}
     */
    @Override
    public long computeNeededChunkMemory() {
        final long uncompressed = getUncompressedSize();
        return uncompressed;
    }

    /**
     * Releases everything this chunk holds: closes every vector of every stored batch and
     * empties the batch list, closes the sorted-index vector if one exists, and clears and
     * drops the schema root if one exists.
     */
    @Override
    public void freeData() {
        for (List<ValueVector> batch : this.batchOfVectors) {
            for (ValueVector vector : batch) {
                vector.close();
            }
        }
        this.batchOfVectors.clear();
        if (this.firstResultChunkSortedIndices != null) {
            this.firstResultChunkSortedIndices.close();
        }
        if (this.root != null) {
            this.root.clear();
            this.root = null;
        }
    }

    /**
     * Creates a new iterator over the batches currently stored in this chunk.
     *
     * @param dataConversionContext context handed to the converters the iterator creates
     * @return a fresh iterator positioned before the first row
     */
    public ArrowChunkIterator getIterator(DataConversionContext dataConversionContext) {
        final ArrowChunkIterator iterator = new ArrowChunkIterator(dataConversionContext);
        return iterator;
    }

    /**
     * Builds an iterator bound to a new empty chunk (no batches) with a {@code null}
     * conversion context.
     *
     * @return an iterator over an empty chunk
     */
    public static ArrowChunkIterator getEmptyChunkIterator() {
        ArrowResultChunk emptyChunk = new EmptyArrowResultChunk();
        return emptyChunk.new ArrowChunkIterator(null);
    }

    /**
     * Turns on sorting for this chunk: iterators will merge all batches into one and sort it
     * when they reach batch 0.
     */
    public void enableSortFirstResultChunk() {
        enableSortFirstResultChunk = true;
    }

    /**
     * Collapses all stored record batches into the first one.
     *
     * <p>Each later batch is appended column by column onto the first batch and then closed;
     * afterwards the batch list contains only the (grown) first batch. With zero or one stored
     * batch there is nothing to merge and the method returns without touching anything.
     *
     * @throws SnowflakeSQLException if a column cannot be merged
     */
    public void mergeBatchesIntoOne() throws SnowflakeSQLException {
        // Guard: get(0) below would throw IndexOutOfBoundsException on an empty chunk, and a
        // single batch is already "merged".
        if (this.batchOfVectors.size() < 2) {
            return;
        }
        try {
            List<ValueVector> first = this.batchOfVectors.get(0);
            for (int i = 1; i < this.batchOfVectors.size(); ++i) {
                List<ValueVector> batch = this.batchOfVectors.get(i);
                this.mergeBatch(first, batch);
                // The values now live in `first`; the source vectors are no longer needed.
                batch.forEach(ValueVector::close);
            }
            this.batchOfVectors.clear();
            this.batchOfVectors.add(first);
        }
        catch (SFException ex) {
            throw new SnowflakeSQLLoggedException(this.session, "XX000", (int)ErrorCode.INTERNAL_ERROR.getMessageCode(), ex, "Failed to merge first result chunk: " + ex.getLocalizedMessage());
        }
    }

    /**
     * Appends each column of {@code right} onto the column at the same position in
     * {@code left}. The number of columns processed is taken from {@code left}.
     *
     * @param left batch that receives the values
     * @param right batch whose values are appended
     * @throws SFException if a column's vector type cannot be merged
     */
    private void mergeBatch(List<ValueVector> left, List<ValueVector> right) throws SFException {
        final int columnCount = left.size();
        int column = 0;
        while (column < columnCount) {
            mergeVector(left.get(column), right.get(column));
            column++;
        }
    }

    /**
     * Appends {@code right} onto {@code left}, dispatching on whether {@code left} is a
     * struct vector.
     *
     * @param left vector that receives the values
     * @param right vector whose values are appended; cast to the same kind as {@code left}
     * @throws SFException if the vector type cannot be merged
     */
    private void mergeVector(ValueVector left, ValueVector right) throws SFException {
        if (!(left instanceof StructVector)) {
            mergeNonStructVector(left, right);
            return;
        }
        mergeStructVector((StructVector) left, (StructVector) right);
    }

    /**
     * Appends a struct vector onto another: first every child vector is merged pairwise (by
     * position, using the child count of {@code left}), then the null positions of
     * {@code right} are replayed onto {@code left} and its value count is extended.
     *
     * <p>NOTE(review): only null rows are marked on {@code left}; non-null appended rows are
     * not explicitly marked as defined at the struct level. Confirm that readers rely on the
     * child vectors rather than on the struct's own validity for those rows.
     *
     * @param left struct vector that receives the values
     * @param right struct vector whose values are appended
     * @throws SFException if a child vector type cannot be merged
     */
    private void mergeStructVector(StructVector left, StructVector right) throws SFException {
        final int childCount = left.getChildrenFromFields().size();
        int child = 0;
        while (child < childCount) {
            mergeNonStructVector(left.getChildrenFromFields().get(child), right.getChildrenFromFields().get(child));
            child++;
        }

        // Captured after the children were merged, exactly as before: the struct's own value
        // count is only updated by the setValueCount call at the end.
        final int base = left.getValueCount();
        final int incoming = right.getValueCount();
        for (int row = 0; row < incoming; row++) {
            if (right.isNull(row)) {
                left.setNull(base + row);
            }
        }
        left.setValueCount(base + incoming);
    }

    /**
     * Appends every value of {@code right} to the end of {@code left}, preserving nulls, and
     * extends the value count of {@code left} accordingly.
     *
     * <p>The concrete type of {@code left} selects the copy routine; {@code right} is cast to
     * that same type. Supported types: BigInt, Bit, DateDay, Decimal, Float8, Int, SmallInt,
     * TinyInt, VarBinary and VarChar vectors.
     *
     * @param left vector that receives the values
     * @param right vector whose values are appended
     * @throws SFException if {@code left} is none of the supported vector types
     */
    private void mergeNonStructVector(ValueVector left, ValueVector right) throws SFException {
        if (left instanceof BigIntVector) {
            BigIntVector dst = (BigIntVector) left;
            BigIntVector src = (BigIntVector) right;
            final int base = dst.getValueCount();
            final int incoming = src.getValueCount();
            for (int row = 0; row < incoming; row++) {
                if (src.isNull(row)) {
                    dst.setNull(base + row);
                } else {
                    dst.setSafe(base + row, src.get(row));
                }
            }
            dst.setValueCount(base + incoming);
            return;
        }

        if (left instanceof BitVector) {
            BitVector dst = (BitVector) left;
            BitVector src = (BitVector) right;
            final int base = dst.getValueCount();
            final int incoming = src.getValueCount();
            for (int row = 0; row < incoming; row++) {
                if (src.isNull(row)) {
                    dst.setNull(base + row);
                } else {
                    try {
                        dst.setSafe(base + row, src.get(row));
                    } catch (IndexOutOfBoundsException e) {
                        // Retry once after growing the destination buffers.
                        dst.reAlloc();
                        dst.setSafe(base + row, src.get(row));
                    }
                }
            }
            dst.setValueCount(base + incoming);
            return;
        }

        if (left instanceof DateDayVector) {
            DateDayVector dst = (DateDayVector) left;
            DateDayVector src = (DateDayVector) right;
            final int base = dst.getValueCount();
            final int incoming = src.getValueCount();
            for (int row = 0; row < incoming; row++) {
                if (src.isNull(row)) {
                    dst.setNull(base + row);
                } else {
                    dst.setSafe(base + row, src.get(row));
                }
            }
            dst.setValueCount(base + incoming);
            return;
        }

        if (left instanceof DecimalVector) {
            DecimalVector dst = (DecimalVector) left;
            DecimalVector src = (DecimalVector) right;
            final int base = dst.getValueCount();
            final int incoming = src.getValueCount();
            for (int row = 0; row < incoming; row++) {
                if (src.isNull(row)) {
                    dst.setNull(base + row);
                } else {
                    dst.setSafe(base + row, src.get(row));
                }
            }
            dst.setValueCount(base + incoming);
            return;
        }

        if (left instanceof Float8Vector) {
            Float8Vector dst = (Float8Vector) left;
            Float8Vector src = (Float8Vector) right;
            final int base = dst.getValueCount();
            final int incoming = src.getValueCount();
            for (int row = 0; row < incoming; row++) {
                if (src.isNull(row)) {
                    dst.setNull(base + row);
                } else {
                    dst.setSafe(base + row, src.get(row));
                }
            }
            dst.setValueCount(base + incoming);
            return;
        }

        if (left instanceof IntVector) {
            IntVector dst = (IntVector) left;
            IntVector src = (IntVector) right;
            final int base = dst.getValueCount();
            final int incoming = src.getValueCount();
            for (int row = 0; row < incoming; row++) {
                if (src.isNull(row)) {
                    dst.setNull(base + row);
                } else {
                    dst.setSafe(base + row, src.get(row));
                }
            }
            dst.setValueCount(base + incoming);
            return;
        }

        if (left instanceof SmallIntVector) {
            SmallIntVector dst = (SmallIntVector) left;
            SmallIntVector src = (SmallIntVector) right;
            final int base = dst.getValueCount();
            final int incoming = src.getValueCount();
            for (int row = 0; row < incoming; row++) {
                if (src.isNull(row)) {
                    dst.setNull(base + row);
                } else {
                    dst.setSafe(base + row, src.get(row));
                }
            }
            dst.setValueCount(base + incoming);
            return;
        }

        if (left instanceof TinyIntVector) {
            TinyIntVector dst = (TinyIntVector) left;
            TinyIntVector src = (TinyIntVector) right;
            final int base = dst.getValueCount();
            final int incoming = src.getValueCount();
            for (int row = 0; row < incoming; row++) {
                if (src.isNull(row)) {
                    dst.setNull(base + row);
                } else {
                    dst.setSafe(base + row, src.get(row));
                }
            }
            dst.setValueCount(base + incoming);
            return;
        }

        if (left instanceof VarBinaryVector) {
            VarBinaryVector dst = (VarBinaryVector) left;
            VarBinaryVector src = (VarBinaryVector) right;
            final int base = dst.getValueCount();
            final int incoming = src.getValueCount();
            for (int row = 0; row < incoming; row++) {
                if (src.isNull(row)) {
                    dst.setNull(base + row);
                } else {
                    dst.setSafe(base + row, src.get(row));
                }
            }
            dst.setValueCount(base + incoming);
            return;
        }

        if (left instanceof VarCharVector) {
            VarCharVector dst = (VarCharVector) left;
            VarCharVector src = (VarCharVector) right;
            final int base = dst.getValueCount();
            final int incoming = src.getValueCount();
            for (int row = 0; row < incoming; row++) {
                if (src.isNull(row)) {
                    dst.setNull(base + row);
                } else {
                    dst.setSafe(base + row, src.get(row));
                }
            }
            dst.setValueCount(base + incoming);
            return;
        }

        throw new SFException(ErrorCode.INTERNAL_ERROR, "Failed to merge vector due to unknown vector type");
    }

    /**
     * Sorts batch 0 of this chunk and stores the resulting index vector in
     * {@code firstResultChunkSortedIndices}.
     *
     * @param converters converters for the columns of batch 0, handed to the sorter
     * @throws SnowflakeSQLException if the sorter fails with an {@link SFException}
     */
    private void sortFirstResultChunk(List<ArrowVectorConverter> converters) throws SnowflakeSQLException {
        final List<ValueVector> firstBatch = batchOfVectors.get(0);
        try {
            firstResultChunkSortedIndices = new ArrowResultChunkIndexSorter(firstBatch, converters).sort();
        }
        catch (SFException ex) {
            throw new SnowflakeSQLException((Throwable)ex, "XX000", (int)ErrorCode.INTERNAL_ERROR.getMessageCode(), "Failed to sort first result chunk: " + ex.getLocalizedMessage());
        }
    }

    /**
     * @return {@code true} once {@link #enableSortFirstResultChunk()} has been called
     */
    private boolean sortFirstResultChunkEnabled() {
        final boolean enabled = enableSortFirstResultChunk;
        return enabled;
    }

    /**
     * Chunk with no data: empty URL, zero rows, columns and size, and neither allocator nor
     * session. Used as the owner of the iterator returned by
     * {@link ArrowResultChunk#getEmptyChunkIterator()}.
     */
    private static class EmptyArrowResultChunk
    extends ArrowResultChunk {
        EmptyArrowResultChunk() {
            super("", 0, 0, 0, null, null);
        }

        /** Nothing is ever held, so there is nothing to release. */
        @Override
        public final void freeData() {
            // intentionally empty
        }

        /** An empty chunk needs no memory. */
        @Override
        public final long computeNeededChunkMemory() {
            return 0L;
        }
    }

    /**
     * Forward-only cursor over the rows of the enclosing chunk, walking its record batches in
     * order and creating one converter per column each time a new batch is entered.
     *
     * <p>When sorting of the first result chunk is enabled on the enclosing chunk, entering
     * batch 0 first merges all batches into one and then sorts it; row positions reported by
     * {@link #getCurrentRowInRecordBatch()} are then mapped through the sorted indices.
     */
    public class ArrowChunkIterator {
        /** Index of the batch being read; -1 until the first successful advance. */
        private int currentRecordBatchIndex = -1;
        /** Number of batches to visit; snapshot taken at construction, set to 1 after a merge. */
        private int totalRecordBatch;
        /** Position within the current batch; -1 before the first call to {@link #next()}. */
        private int currentRowInRecordBatch;
        /** Value count of the first column of the current batch. */
        private int rowCountInCurrentRecordBatch;
        /** Converters for the current batch, one per column; null until a batch is entered. */
        private List<ArrowVectorConverter> currentConverters;
        private final DataConversionContext dataConversionContext;

        ArrowChunkIterator(DataConversionContext dataConversionContext) {
            this.totalRecordBatch = ArrowResultChunk.this.batchOfVectors.size();
            this.currentRowInRecordBatch = -1;
            this.rowCountInCurrentRecordBatch = 0;
            this.dataConversionContext = dataConversionContext;
        }

        /**
         * Creates one converter per vector, passing the vector's position as its column index.
         *
         * @param vectors column vectors of one batch
         * @return converters in the same order as {@code vectors}
         * @throws SnowflakeSQLException if a converter cannot be created
         */
        private List<ArrowVectorConverter> initConverters(List<ValueVector> vectors) throws SnowflakeSQLException {
            List<ArrowVectorConverter> converters = new ArrayList<>(vectors.size());
            for (int i = 0; i < vectors.size(); ++i) {
                converters.add(ArrowVectorConverterUtil.initConverter(vectors.get(i), this.dataConversionContext, ArrowResultChunk.this.session, i));
            }
            return converters;
        }

        /**
         * Advances to the next row, entering the next batch when the current one is exhausted.
         *
         * @return {@code true} if the cursor moved within the current batch or entered another
         *     batch; {@code false} once all batches have been visited
         * @throws SnowflakeSQLException if merging, converter creation or sorting fails
         */
        public boolean next() throws SnowflakeSQLException {
            ++this.currentRowInRecordBatch;
            if (this.currentRowInRecordBatch < this.rowCountInCurrentRecordBatch) {
                return true;
            }

            ++this.currentRecordBatchIndex;
            if (this.currentRecordBatchIndex >= this.totalRecordBatch) {
                return false;
            }

            this.currentRowInRecordBatch = 0;
            final boolean sortThisBatch = this.currentRecordBatchIndex == 0 && ArrowResultChunk.this.sortFirstResultChunkEnabled();
            if (sortThisBatch && ArrowResultChunk.this.batchOfVectors.size() > 1) {
                // Sorting works on a single batch, so fold everything into batch 0 first.
                ArrowResultChunk.this.mergeBatchesIntoOne();
                this.totalRecordBatch = 1;
            }

            List<ValueVector> batch = ArrowResultChunk.this.batchOfVectors.get(this.currentRecordBatchIndex);
            // The row count of the batch is read from its first column.
            this.rowCountInCurrentRecordBatch = batch.get(0).getValueCount();
            this.currentConverters = this.initConverters(batch);
            if (sortThisBatch) {
                ArrowResultChunk.this.sortFirstResultChunk(this.currentConverters);
            }
            return true;
        }

        /**
         * @return {@code true} when the cursor is in the final batch and on that batch's final
         *     row, judged by the current batch index and row position
         */
        public boolean isLast() {
            return this.currentRecordBatchIndex + 1 == this.totalRecordBatch && this.currentRowInRecordBatch + 1 == this.rowCountInCurrentRecordBatch;
        }

        /**
         * @return {@code true} when both the batch index has reached the batch total and the
         *     row position has reached the current batch's row count
         */
        public boolean isAfterLast() {
            return this.currentRecordBatchIndex >= this.totalRecordBatch && this.currentRowInRecordBatch >= this.rowCountInCurrentRecordBatch;
        }

        /** @return the chunk this iterator reads from */
        public ArrowResultChunk getChunk() {
            return ArrowResultChunk.this;
        }

        /**
         * Returns the converter for a column of the current batch.
         *
         * @param columnIdx zero-based column index
         * @return the converter at that index
         * @throws SFException with {@code COLUMN_DOES_NOT_EXIST} (reporting the one-based
         *     index) if {@code columnIdx} is out of range
         */
        public ArrowVectorConverter getCurrentConverter(int columnIdx) throws SFException {
            if (columnIdx < 0 || columnIdx >= this.currentConverters.size()) {
                throw new SFException(ErrorCode.COLUMN_DOES_NOT_EXIST, columnIdx + 1);
            }
            return this.currentConverters.get(columnIdx);
        }

        /**
         * @return the row index to read in the current batch: the entry of the sorted-index
         *     vector at the cursor position when sorting is enabled and the cursor is in
         *     batch 0, otherwise the cursor position itself
         */
        public int getCurrentRowInRecordBatch() {
            if (ArrowResultChunk.this.sortFirstResultChunkEnabled() && this.currentRecordBatchIndex == 0) {
                return ArrowResultChunk.this.firstResultChunkSortedIndices.get(this.currentRowInRecordBatch);
            }
            return this.currentRowInRecordBatch;
        }
    }
}

