/*
 * Decompiled with CFR 0.152.
 */
package net.ifok.png.compress;

import java.awt.image.BufferedImage;
import java.awt.image.ColorModel;
import java.awt.image.DataBuffer;
import java.awt.image.DataBufferByte;
import java.awt.image.DataBufferInt;
import java.awt.image.DataBufferUShort;
import java.awt.image.DirectColorModel;
import java.awt.image.IndexColorModel;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Set;
import java.util.zip.Deflater;
import net.ifok.png.compress.BitWriter;
import net.ifok.png.compress.D3;
import net.ifok.png.compress.D4;
import net.ifok.png.compress.PngBitReader;
import net.ifok.png.compress.PngChunk;
import net.ifok.png.compress.Quant24;
import net.ifok.png.compress.Quant32;

/**
 * Writes {@link BufferedImage}s as PNG streams.
 *
 * <p>Two modes exist, selected with {@link #setCompressed(boolean)}:
 * <ul>
 *   <li><b>normal</b> (default): the image is written in a colour layout derived from its
 *       {@link ColorModel}, deflated at the fastest level;</li>
 *   <li><b>compressed</b>: RGB/ARGB images of the common {@code BufferedImage} types are reduced
 *       to an 8-bit palette and deflated at the highest level; any other image type is written
 *       in normal layout.</li>
 * </ul>
 *
 * <p>NOTE(review): the raw-buffer fast paths read the raster's backing array from index 0 without
 * a scanline stride, so they presumably expect a tightly packed raster (not one produced by
 * {@code getSubimage}) — confirm against callers.
 */
public class PngEncoder {
    /** PNG colour type: greyscale. */
    private static final int COLOR_GRAY = 0;
    /** PNG colour type: RGB. */
    private static final int COLOR_RGB = 2;
    /** PNG colour type: palette index. */
    private static final int COLOR_INDEXED = 3;
    /** PNG colour type: greyscale with alpha. */
    private static final int COLOR_GRAY_ALPHA = 4;
    /** PNG colour type: RGB with alpha. */
    private static final int COLOR_RGB_ALPHA = 6;

    /** Largest palette an 8-bit indexed image can address. */
    private static final int MAX_PALETTE_SIZE = 256;
    /** Lower bound for the deflate scratch buffer; a zero-length buffer would never make progress. */
    private static final int MIN_DEFLATE_BUFFER = 64;
    /** Upper bound for the deflate scratch buffer. */
    private static final int MAX_DEFLATE_BUFFER = 4096;
    /** Marker shift passed to {@link #packShifted} when no alpha byte must be emitted. */
    private static final int NO_ALPHA = -1;

    private boolean compress;

    /**
     * Encodes {@code image} as PNG and writes it to {@code outputStream}.
     *
     * <p>The stream is neither flushed nor closed by this method.
     *
     * @param image        image to encode
     * @param outputStream destination of the PNG bytes
     * @throws IOException if writing fails or the image layout cannot be encoded
     */
    public void write(BufferedImage image, OutputStream outputStream) throws IOException {
        if (this.compress) {
            this.compress8Bit(image, outputStream);
        } else {
            this.compressNormal(image, outputStream);
        }
    }

    /**
     * @return {@code true} when palette reduction and maximum deflate level are enabled
     */
    public boolean isCompressed() {
        return this.compress;
    }

    /**
     * Enables or disables compressed mode (palette reduction plus maximum deflate level).
     *
     * @param compress {@code true} to enable compressed mode
     */
    public void setCompressed(boolean compress) {
        this.compress = compress;
    }

    /**
     * Writes the image using the colour layout implied by its {@link ColorModel}.
     *
     * <p>All pixel, palette and deflate work happens before the first byte is written, so a
     * failure while reading the image does not leave a truncated PNG on the stream.
     */
    private void compressNormal(BufferedImage image, OutputStream outputStream) throws IOException {
        int bh = image.getHeight();
        int bw = image.getWidth();
        ColorModel colorModel = image.getColorModel();
        boolean hasAlpha = colorModel.hasAlpha();
        int nComp = colorModel.getNumComponents();
        boolean isIndexed = colorModel instanceof IndexColorModel;
        int bitDepth = PngEncoder.calculateBitDepth(colorModel.getPixelSize(), nComp);

        int colType;
        if (isIndexed) {
            colType = COLOR_INDEXED;
            // An indexed pixel is a single sample regardless of the palette's component count.
            nComp = 1;
        } else if (nComp < 3 || bitDepth < 8) {
            colType = hasAlpha ? COLOR_GRAY_ALPHA : COLOR_GRAY;
        } else {
            colType = hasAlpha ? COLOR_RGB_ALPHA : COLOR_RGB;
        }

        byte[] pixels = isIndexed && bitDepth != 8
                ? PngEncoder.getIndexedPaletteData(image)
                : PngEncoder.getPixelData(image, bitDepth, nComp, bw, bh);

        byte[] palette = null;
        byte[] trnsBytes = null;
        if (isIndexed) {
            IndexColorModel indexModel = (IndexColorModel) colorModel;
            int paletteSize = indexModel.getMapSize();
            int[] rgbs = new int[paletteSize];
            indexModel.getRGBs(rgbs);
            if (bitDepth == 8) {
                // Collapses duplicate palette entries and rewrites the pixel indices to match.
                paletteSize = PngEncoder.reduceIndexMap(paletteSize, rgbs, pixels);
            }
            palette = new byte[paletteSize * 3];
            int p = 0;
            for (int i = 0; i < paletteSize; ++i) {
                int color = rgbs[i];
                palette[p++] = (byte) (color >> 16);
                palette[p++] = (byte) (color >> 8);
                palette[p++] = (byte) color;
            }
            if (indexModel.getNumComponents() == 4) {
                trnsBytes = new byte[paletteSize];
                for (int i = 0; i < paletteSize; ++i) {
                    trnsBytes[i] = (byte) (rgbs[i] >> 24);
                }
            }
        }
        byte[] deflated = this.getDeflatedData(pixels);

        outputStream.write(PngChunk.SIGNATURE);
        PngEncoder.writeChunk(outputStream,
                PngChunk.createHeaderChunk(bw, bh, (byte) bitDepth, (byte) colType, (byte) 0, (byte) 0, (byte) 0));
        if (palette != null) {
            PngEncoder.writeChunk(outputStream, PngChunk.createPaleteChunk(palette));
        }
        if (trnsBytes != null) {
            PngEncoder.writeChunk(outputStream, PngChunk.createTrnsChunk(trnsBytes));
        }
        PngEncoder.writeChunk(outputStream, PngChunk.createDataChunk(deflated));
        PngEncoder.writeChunk(outputStream, PngChunk.createEndChunk());
    }

    /** Writes one chunk as length, name, data and CRC, in that order. */
    private static void writeChunk(OutputStream outputStream, PngChunk chunk) throws IOException {
        outputStream.write(chunk.getLength());
        outputStream.write(chunk.getName());
        outputStream.write(chunk.getData());
        outputStream.write(chunk.getCRCValue());
    }

    /**
     * Removes duplicate colours from a palette and remaps the pixel indices accordingly.
     *
     * @param indexModelMapSize number of valid entries in {@code rgbs}
     * @param rgbs              palette colours; compacted in place when duplicates exist
     * @param pixels            index bytes (including per-row filter bytes); remapped in place
     * @return number of distinct colours
     */
    private static int reduceIndexMap(int indexModelMapSize, int[] rgbs, byte[] pixels) {
        int numColors = 0;
        byte[] indexMap = new byte[indexModelMapSize];
        // Insertion order is what defines the new palette order.
        LinkedHashMap<Integer, Integer> colors = new LinkedHashMap<Integer, Integer>();
        for (int i = 0; i < indexModelMapSize; ++i) {
            Integer color = Integer.valueOf(rgbs[i]);
            Integer known = colors.get(color);
            if (known == null) {
                indexMap[i] = (byte) numColors;
                colors.put(color, Integer.valueOf(numColors));
                ++numColors;
            } else {
                indexMap[i] = (byte) known.intValue();
            }
        }
        if (numColors < indexModelMapSize) {
            // The filter bytes (value 0) embedded in pixels survive this remap because the first
            // palette entry always keeps index 0.
            for (int i = 0; i < pixels.length; ++i) {
                pixels[i] = indexMap[pixels[i] & 0xFF];
            }
            int next = 0;
            for (Integer color : colors.keySet()) {
                rgbs[next++] = color.intValue();
            }
        }
        return numColors;
    }

    /**
     * @return {@code true} if at least one entry is not fully opaque (0xFF)
     */
    private static boolean isAlphaUsed(byte[] trnsBytes) {
        for (byte alpha : trnsBytes) {
            if (alpha != -1) {
                return true;
            }
        }
        return false;
    }

    /**
     * Writes the image as an 8-bit indexed PNG.
     *
     * <p>Images with at most 256 distinct colours are indexed exactly; otherwise the palette comes
     * from {@code Quant24}/{@code Quant32} and the indices from {@code D3}/{@code D4}
     * (presumably quantisation and dithering — confirm in those classes). Image types other than
     * 3BYTE_BGR, 4BYTE_ABGR, INT_BGR, INT_ARGB and INT_RGB are delegated to
     * {@link #compressNormal}.
     */
    private void compress8Bit(BufferedImage image, OutputStream outputStream) throws IOException {
        int bh = image.getHeight();
        int bw = image.getWidth();
        int[][] argb = null;
        int[][] rgb = null;
        switch (image.getType()) {
            case BufferedImage.TYPE_3BYTE_BGR: {
                rgb = PngEncoder.readByteRows(image, bw, bh, false);
                break;
            }
            case BufferedImage.TYPE_4BYTE_ABGR: {
                argb = PngEncoder.readByteRows(image, bw, bh, true);
                break;
            }
            case BufferedImage.TYPE_INT_BGR: {
                rgb = PngEncoder.readIntRows(image, bw, bh, 0xFFFFFF, true);
                break;
            }
            case BufferedImage.TYPE_INT_ARGB: {
                argb = PngEncoder.readIntRows(image, bw, bh, -1, false);
                break;
            }
            case BufferedImage.TYPE_INT_RGB: {
                // Mask the unused high byte so stray bits cannot split one colour into several
                // palette entries.
                rgb = PngEncoder.readIntRows(image, bw, bh, 0xFFFFFF, false);
                break;
            }
            default: {
                this.compressNormal(image, outputStream);
                return;
            }
        }

        byte[] indices;
        byte[] colorPalette;
        byte[] trnsBytes = null;
        if (argb != null) {
            Object[] exact = PngEncoder.getIndexedMap(argb);
            if (exact != null) {
                indices = (byte[]) exact[0];
                colorPalette = (byte[]) exact[1];
                trnsBytes = (byte[]) exact[2];
            } else {
                Quant32 quantizer = new Quant32();
                Object[] quantized = quantizer.getPalette(argb);
                colorPalette = (byte[]) quantized[0];
                trnsBytes = (byte[]) quantized[1];
                indices = D4.process(colorPalette, trnsBytes, argb, bh, bw);
            }
            // A tRNS chunk is pointless when every palette entry is opaque.
            if (!PngEncoder.isAlphaUsed(trnsBytes)) {
                trnsBytes = null;
            }
        } else {
            Object[] exact = PngEncoder.getIndexedMap(rgb);
            if (exact != null) {
                indices = (byte[]) exact[0];
                colorPalette = (byte[]) exact[1];
            } else {
                Quant24 quantizer = new Quant24();
                colorPalette = quantizer.getPalette(rgb);
                indices = D3.process(colorPalette, rgb, bh, bw);
            }
        }
        byte[] deflated = this.getDeflatedData(PngEncoder.addFilterBytes(indices, bw, bh));

        outputStream.write(PngChunk.SIGNATURE);
        PngEncoder.writeChunk(outputStream,
                PngChunk.createHeaderChunk(bw, bh, (byte) 8, (byte) COLOR_INDEXED, (byte) 0, (byte) 0, (byte) 0));
        PngEncoder.writeChunk(outputStream, PngChunk.createPaleteChunk(colorPalette));
        if (trnsBytes != null) {
            PngEncoder.writeChunk(outputStream, PngChunk.createTrnsChunk(trnsBytes));
        }
        PngEncoder.writeChunk(outputStream, PngChunk.createDataChunk(deflated));
        PngEncoder.writeChunk(outputStream, PngChunk.createEndChunk());
    }

    /**
     * Reads a byte-backed BGR or ABGR raster into rows of packed {@code 0xAARRGGBB} ints
     * (alpha is 0 when {@code withAlpha} is {@code false}).
     */
    private static int[][] readByteRows(BufferedImage image, int bw, int bh, boolean withAlpha) {
        byte[] data = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
        int[][] rows = new int[bh][bw];
        int p = 0;
        for (int y = 0; y < bh; ++y) {
            int[] row = rows[y];
            for (int x = 0; x < bw; ++x) {
                int a = withAlpha ? data[p++] & 0xFF : 0;
                int b = data[p++] & 0xFF;
                int g = data[p++] & 0xFF;
                int r = data[p++] & 0xFF;
                row[x] = a << 24 | r << 16 | g << 8 | b;
            }
        }
        return rows;
    }

    /**
     * Reads an int-backed raster into rows of packed ints.
     *
     * @param keepMask    bits of each raw pixel to keep
     * @param swapRedBlue whether to exchange the lowest and third-lowest bytes (BGR to RGB)
     */
    private static int[][] readIntRows(BufferedImage image, int bw, int bh, int keepMask, boolean swapRedBlue) {
        int[] data = ((DataBufferInt) image.getRaster().getDataBuffer()).getData();
        int[][] rows = new int[bh][bw];
        int p = 0;
        for (int y = 0; y < bh; ++y) {
            int[] row = rows[y];
            for (int x = 0; x < bw; ++x) {
                int val = data[p++] & keepMask;
                if (swapRedBlue) {
                    val = (val & 0xFF00FF00) | ((val & 0xFF) << 16) | ((val >>> 16) & 0xFF);
                }
                row[x] = val;
            }
        }
        return rows;
    }

    /** Prefixes every row of {@code bw} index bytes with the filter byte 0. */
    private static byte[] addFilterBytes(byte[] indices, int bw, int bh) {
        byte[] out = new byte[bw * bh + bh];
        int src = 0;
        int dst = 0;
        for (int y = 0; y < bh; ++y) {
            out[dst++] = 0;
            System.arraycopy(indices, src, out, dst, bw);
            src += bw;
            dst += bw;
        }
        return out;
    }

    /**
     * Builds an exact palette for an image with at most 256 distinct pixel values.
     *
     * @param pixel rows of packed {@code 0xAARRGGBB} pixels; must contain at least one row
     * @return {@code {indices, rgbPalette, alphaPerEntry}} as three {@code byte[]}s, or
     *         {@code null} when the image has more than 256 distinct values
     */
    private static Object[] getIndexedMap(int[][] pixel) {
        int h = pixel.length;
        int w = pixel[0].length;
        int[] colors = new int[MAX_PALETTE_SIZE];
        int colorCount = 0;
        int p = 0;
        byte[] indexedBytes = new byte[h * w];
        HashMap<Integer, Integer> map = new HashMap<Integer, Integer>();
        for (int y = 0; y < h; ++y) {
            int[] row = pixel[y];
            for (int x = 0; x < w; ++x) {
                int key = row[x];
                Integer index = map.get(Integer.valueOf(key));
                if (index == null) {
                    if (colorCount >= MAX_PALETTE_SIZE) {
                        return null;
                    }
                    index = Integer.valueOf(colorCount);
                    map.put(Integer.valueOf(key), index);
                    colors[colorCount++] = key;
                }
                indexedBytes[p++] = (byte) index.intValue();
            }
        }
        byte[] palette = new byte[colorCount * 3];
        byte[] trns = new byte[colorCount];
        p = 0;
        for (int i = 0; i < colorCount; ++i) {
            int val = colors[i];
            trns[i] = (byte) (val >> 24 & 0xFF);
            palette[p++] = (byte) (val >> 16 & 0xFF);
            palette[p++] = (byte) (val >> 8 & 0xFF);
            palette[p++] = (byte) (val & 0xFF);
        }
        return new Object[]{indexedBytes, palette, trns};
    }

    /**
     * Copies a byte-backed indexed raster row by row, prefixing each row with the filter byte 0.
     * The row length is taken as {@code data.length / height}.
     */
    private static byte[] getIndexedPaletteData(BufferedImage buff) throws IOException {
        byte[] pixels = ((DataBufferByte) buff.getRaster().getDataBuffer()).getData();
        int ih = buff.getHeight();
        int len = pixels.length / ih;
        byte[] out = new byte[(len + 1) * ih];
        int src = 0;
        int dst = 0;
        for (int i = 0; i < ih; ++i) {
            out[dst++] = 0;
            System.arraycopy(pixels, src, out, dst, len);
            src += len;
            dst += len;
        }
        return out;
    }

    /**
     * Produces the unfiltered scanline data (filter byte 0 followed by the samples of each row).
     *
     * @param bitDepth bits per sample: 1, 2, 4, 8 or 16
     * @param nComp    samples per pixel
     * @throws IOException if the bit depth or raster layout is not supported
     */
    private static byte[] getPixelData(BufferedImage buff, int bitDepth, int nComp, int bw, int bh) throws IOException {
        switch (bitDepth) {
            case 1:
            case 2:
            case 4: {
                return PngEncoder.packSubByteSamples(buff, bitDepth, bw);
            }
            case 8: {
                return PngEncoder.packEightBitSamples(buff, nComp, bw, bh);
            }
            case 16: {
                return PngEncoder.packSixteenBitSamples(buff, nComp, bw, bh);
            }
            default: {
                throw new IOException("Unsupported PNG bit depth: " + bitDepth);
            }
        }
    }

    /**
     * Re-packs 1/2/4-bit samples from a byte-backed raster.
     *
     * <p>NOTE(review): every sample in the backing array is consumed in sequence, so this
     * presumably expects rows without padding bits — confirm against {@code PngBitReader}.
     */
    private static byte[] packSubByteSamples(BufferedImage buff, int bitDepth, int bw) throws IOException {
        byte[] pixels = ((DataBufferByte) buff.getRaster().getDataBuffer()).getData();
        int samplesPerByte = bitDepth == 1 ? 8 : (bitDepth == 2 ? 4 : 2);
        PngBitReader reader = new PngBitReader(pixels, true);
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        BitWriter writer = new BitWriter(bos);
        int col = 0;
        int sampleCount = pixels.length * samplesPerByte;
        for (int i = 0; i < sampleCount; ++i) {
            if (col == 0) {
                writer.writeByte((byte) 0);
            }
            writer.writeBits(reader.getPositive(bitDepth), bitDepth);
            if (++col == bw) {
                col = 0;
            }
        }
        writer.end();
        bos.flush();
        bos.close();
        return bos.toByteArray();
    }

    /**
     * Packs 8-bit samples. Byte and int rasters are read directly; any other raster with three
     * or more components is converted through {@link BufferedImage#getRGB}.
     */
    private static byte[] packEightBitSamples(BufferedImage buff, int nComp, int bw, int bh) throws IOException {
        DataBuffer dataBuff = buff.getRaster().getDataBuffer();
        switch (dataBuff.getDataType()) {
            case DataBuffer.TYPE_BYTE: {
                return PngEncoder.packByteSamples(((DataBufferByte) dataBuff).getData(), buff.getType(), nComp, bw, bh);
            }
            case DataBuffer.TYPE_INT: {
                return PngEncoder.packIntSamples(buff, ((DataBufferInt) dataBuff).getData(), nComp, bw, bh);
            }
            default: {
                if (nComp >= 3) {
                    // e.g. USHORT 565/555 rasters: let the colour model do the conversion.
                    return PngEncoder.packViaDefaultRgb(buff, buff.getColorModel().hasAlpha(), bw, bh);
                }
                throw new IOException("Unsupported data buffer type " + dataBuff.getDataType()
                        + " for 8-bit PNG output");
            }
        }
    }

    /**
     * Packs byte samples, reversing the component order of BGR/ABGR image types so that the
     * output is RGB/RGBA.
     */
    private static byte[] packByteSamples(byte[] samples, int imageType, int nComp, int bw, int bh) {
        boolean reversed = imageType == BufferedImage.TYPE_3BYTE_BGR
                || imageType == BufferedImage.TYPE_4BYTE_ABGR
                || imageType == BufferedImage.TYPE_4BYTE_ABGR_PRE;
        byte[] out = new byte[bw * bh * nComp + bh];
        int k = 0;
        int col = 0;
        for (int p = 0; p < samples.length; p += nComp) {
            if (col == 0) {
                out[k++] = 0;
            }
            for (int i = 0; i < nComp; ++i) {
                out[k++] = samples[reversed ? p + nComp - 1 - i : p + i];
            }
            if (++col == bw) {
                col = 0;
            }
        }
        return out;
    }

    /** Packs int-backed pixels according to the image type or, failing that, the colour model. */
    private static byte[] packIntSamples(BufferedImage buff, int[] pixInt, int nComp, int bw, int bh) {
        switch (buff.getType()) {
            case BufferedImage.TYPE_INT_ARGB:
            case BufferedImage.TYPE_INT_ARGB_PRE: {
                return PngEncoder.packShifted(pixInt, bw, bh, 16, 8, 0, 24);
            }
            case BufferedImage.TYPE_INT_RGB: {
                return PngEncoder.packShifted(pixInt, bw, bh, 16, 8, 0, NO_ALPHA);
            }
            case BufferedImage.TYPE_INT_BGR: {
                return PngEncoder.packShifted(pixInt, bw, bh, 0, 8, 16, NO_ALPHA);
            }
            default: {
                break;
            }
        }
        ColorModel model = buff.getColorModel();
        if (model instanceof DirectColorModel) {
            DirectColorModel dm = (DirectColorModel) model;
            // Emit an alpha byte only when the model has one; otherwise the row length would
            // not match the RGB colour type announced in the header.
            int aShift = dm.hasAlpha() ? PngEncoder.getMaskValue(dm.getAlphaMask()) : NO_ALPHA;
            return PngEncoder.packShifted(pixInt, bw, bh,
                    PngEncoder.getMaskValue(dm.getRedMask()),
                    PngEncoder.getMaskValue(dm.getGreenMask()),
                    PngEncoder.getMaskValue(dm.getBlueMask()),
                    aShift);
        }
        byte[] out = new byte[bw * bh * nComp + bh];
        int k = 0;
        int col = 0;
        for (int value : pixInt) {
            if (col == 0) {
                out[k++] = 0;
            }
            byte[] t = PngChunk.intToBytes(value);
            switch (nComp) {
                case 4: {
                    out[k++] = t[1];
                    out[k++] = t[2];
                    out[k++] = t[3];
                    out[k++] = t[0];
                    break;
                }
                case 3: {
                    out[k++] = t[1];
                    out[k++] = t[2];
                    out[k++] = t[3];
                    break;
                }
                case 2: {
                    out[k++] = t[2];
                    out[k++] = t[3];
                    break;
                }
                case 1: {
                    out[k++] = t[3];
                    break;
                }
                default: {
                    break;
                }
            }
            if (++col == bw) {
                col = 0;
            }
        }
        return out;
    }

    /**
     * Packs int pixels into RGB or RGBA rows by shifting each channel into the low byte.
     *
     * @param aShift alpha shift, or {@link #NO_ALPHA} to emit three bytes per pixel
     */
    private static byte[] packShifted(int[] pixInt, int bw, int bh, int rShift, int gShift, int bShift, int aShift) {
        boolean withAlpha = aShift != NO_ALPHA;
        byte[] out = new byte[bw * bh * (withAlpha ? 4 : 3) + bh];
        int k = 0;
        int p = 0;
        for (int y = 0; y < bh; ++y) {
            out[k++] = 0;
            for (int x = 0; x < bw; ++x) {
                int val = pixInt[p++];
                out[k++] = (byte) (val >> rShift);
                out[k++] = (byte) (val >> gShift);
                out[k++] = (byte) (val >> bShift);
                if (withAlpha) {
                    out[k++] = (byte) (val >> aShift);
                }
            }
        }
        return out;
    }

    /** Packs RGB or RGBA rows from the values returned by {@link BufferedImage#getRGB}. */
    private static byte[] packViaDefaultRgb(BufferedImage buff, boolean withAlpha, int bw, int bh) {
        byte[] out = new byte[bw * bh * (withAlpha ? 4 : 3) + bh];
        int[] row = new int[bw];
        int k = 0;
        for (int y = 0; y < bh; ++y) {
            buff.getRGB(0, y, bw, 1, row, 0, bw);
            out[k++] = 0;
            for (int x = 0; x < bw; ++x) {
                int val = row[x];
                out[k++] = (byte) (val >> 16);
                out[k++] = (byte) (val >> 8);
                out[k++] = (byte) val;
                if (withAlpha) {
                    out[k++] = (byte) (val >> 24);
                }
            }
        }
        return out;
    }

    /** Packs 16-bit samples from an unsigned-short raster, most significant byte first. */
    private static byte[] packSixteenBitSamples(BufferedImage buff, int nComp, int bw, int bh) throws IOException {
        DataBuffer dataBuff = buff.getRaster().getDataBuffer();
        if (!(dataBuff instanceof DataBufferUShort)) {
            throw new IOException("Unsupported data buffer type " + dataBuff.getDataType()
                    + " for 16-bit PNG output");
        }
        short[] shortPixels = ((DataBufferUShort) dataBuff).getData();
        // ByteBuffer defaults to big-endian, which is the byte order putShort must produce here.
        ByteBuffer out = ByteBuffer.allocate(shortPixels.length * 2 + bh);
        int col = 0;
        for (int p = 0; p < shortPixels.length; p += nComp) {
            if (col == 0) {
                out.put((byte) 0);
            }
            for (int i = 0; i < nComp; ++i) {
                out.putShort(shortPixels[p + i]);
            }
            if (++col == bw) {
                col = 0;
            }
        }
        return out.array();
    }

    /**
     * Translates a byte-aligned 8-bit channel mask into its right-shift distance.
     *
     * @return 0, 8 or 16 for the three low byte masks; 24 for every other mask
     */
    private static int getMaskValue(int mask) {
        if (mask == 0xFF) {
            return 0;
        }
        if (mask == 0xFF00) {
            return 8;
        }
        if (mask == 0xFF0000) {
            return 16;
        }
        return 24;
    }

    /**
     * Derives the bits per sample from the colour model's pixel size.
     *
     * @return {@code pixelBits} itself when below 8; otherwise bits per component if that is
     *         8 or 16, else 8
     */
    private static int calculateBitDepth(int pixelBits, int nComp) {
        if (pixelBits < 8) {
            return pixelBits;
        }
        int perComponent = pixelBits / nComp;
        return perComponent == 16 ? 16 : 8;
    }

    /**
     * Deflates {@code pixels}, using the highest compression level in compressed mode and the
     * fastest level otherwise.
     */
    private byte[] getDeflatedData(byte[] pixels) throws IOException {
        Deflater deflater = new Deflater(this.compress ? Deflater.BEST_COMPRESSION : Deflater.BEST_SPEED);
        // Clamp from below: with an empty scratch buffer deflate() returns 0 forever.
        int bufferSize = Math.max(MIN_DEFLATE_BUFFER, Math.min(pixels.length / 2, MAX_DEFLATE_BUFFER));
        ByteArrayOutputStream deflated = new ByteArrayOutputStream(bufferSize);
        byte[] buffer = new byte[bufferSize];
        try {
            deflater.setInput(pixels);
            deflater.finish();
            while (!deflater.finished()) {
                int count = deflater.deflate(buffer);
                deflated.write(buffer, 0, count);
            }
        } finally {
            // Releases the native zlib state even when deflation fails.
            deflater.end();
        }
        return deflated.toByteArray();
    }
}

