001// Licensed under the MIT license. See LICENSE file in the project root for full license information.
002
003package de.bytefish.pgbulkinsert.pgsql.handlers;
004
005import de.bytefish.pgbulkinsert.util.BigDecimalUtils;
006
007import java.io.DataOutputStream;
008import java.math.BigDecimal;
009import java.math.BigInteger;
010import java.util.ArrayList;
011import java.util.List;
012
013/**
014 * The Algorithm for turning a BigDecimal into a Postgres Numeric is heavily inspired by the Intermine Implementation:
015 * <p>
016 * https://github.com/intermine/intermine/blob/master/intermine/objectstore/main/src/org/intermine/sql/writebatch/BatchWriterPostgresCopyImpl.java
017 * <p>
018 *  please see struct definition of @{link NumericVar} for numeric data type byte structure at:
019 *  https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/numeric.c
020 */
021public class BigDecimalValueHandler<T extends Number> extends BaseValueHandler<T> {
022
023    private static final int DECIMAL_DIGITS = 4;
024    private static final BigInteger TEN_THOUSAND = new BigInteger("10000");
025
026    @Override
027    protected void internalHandle(final DataOutputStream buffer, final T value) throws Exception {
028        final BigDecimal tmpValue = getNumericAsBigDecimal(value);
029
030        // Number of fractional digits:
031        final int fractionDigits = tmpValue.scale();
032
033        // Number of Fraction Groups:
034        final int fractionGroups = fractionDigits > 0 ? (fractionDigits + 3) / 4 : 0;
035
036        final List<Integer> digits = digits(tmpValue);
037
038        buffer.writeInt(8 + (2 * digits.size()));
039        buffer.writeShort(digits.size());
040        buffer.writeShort(digits.size() - fractionGroups - 1);
041        buffer.writeShort(tmpValue.signum() == 1 ? 0x0000 : 0x4000);
042        buffer.writeShort(fractionDigits > 0 ? fractionDigits : 0);
043
044        // Now write each digit:
045        for (int pos = digits.size() - 1; pos >= 0; pos--) {
046            final int valueToWrite = digits.get(pos);
047            buffer.writeShort(valueToWrite);
048        }
049    }
050
051    @Override
052    public int getLength(final T value) {
053        final List<Integer> digits = digits(getNumericAsBigDecimal(value));
054        return (8 + (2 * digits.size()));
055    }
056
057    private static BigDecimal getNumericAsBigDecimal(final Number source) {
058        if (source instanceof BigDecimal) {
059            return (BigDecimal) source;
060        }
061        if (source instanceof BigInteger) {
062            return new BigDecimal((BigInteger) source);
063        }
064        return BigDecimalUtils.toBigDecimal(source.doubleValue());
065    }
066
067    private List<Integer> digits(final BigDecimal value) {
068        BigInteger unscaledValue = value.unscaledValue();
069
070        if (value.signum() == -1) {
071            unscaledValue = unscaledValue.negate();
072        }
073
074        final List<Integer> digits = new ArrayList<>();
075
076        if (value.scale() > 0) {
077            // The scale needs to be a multiple of 4:
078            int scaleRemainder = value.scale() % 4;
079
080            // Scale the first value:
081            if (scaleRemainder != 0) {
082                final BigInteger[] result = unscaledValue.divideAndRemainder(BigInteger.TEN.pow(scaleRemainder));
083                final int digit = result[1].intValue() * (int) Math.pow(10, DECIMAL_DIGITS - scaleRemainder);
084                digits.add(digit);
085                unscaledValue = result[0];
086            }
087
088            while (!unscaledValue.equals(BigInteger.ZERO)) {
089                final BigInteger[] result = unscaledValue.divideAndRemainder(TEN_THOUSAND);
090                digits.add(result[1].intValue());
091                unscaledValue = result[0];
092            }
093        } else {
094            BigInteger originalValue = unscaledValue.multiply(BigInteger.TEN.pow(Math.abs(value.scale())));
095            while (!originalValue.equals(BigInteger.ZERO)) {
096                final BigInteger[] result = originalValue.divideAndRemainder(TEN_THOUSAND);
097                digits.add(result[1].intValue());
098                originalValue = result[0];
099            }
100        }
101
102        return digits;
103    }
104}