001// Licensed under the MIT license. See LICENSE file in the project root for full license information.
002
003package de.bytefish.pgbulkinsert.mapping;
004
005import de.bytefish.pgbulkinsert.function.ToBooleanFunction;
006import de.bytefish.pgbulkinsert.function.ToFloatFunction;
007import de.bytefish.pgbulkinsert.model.ColumnDefinition;
008import de.bytefish.pgbulkinsert.model.TableDefinition;
009import de.bytefish.pgbulkinsert.pgsql.PgBinaryWriter;
010import de.bytefish.pgbulkinsert.pgsql.constants.DataType;
011import de.bytefish.pgbulkinsert.pgsql.constants.ObjectIdentifier;
012import de.bytefish.pgbulkinsert.pgsql.handlers.*;
013import de.bytefish.pgbulkinsert.pgsql.model.geometric.*;
014import de.bytefish.pgbulkinsert.pgsql.model.interval.Interval;
015import de.bytefish.pgbulkinsert.pgsql.model.network.MacAddress;
016import de.bytefish.pgbulkinsert.pgsql.model.range.Range;
017import de.bytefish.pgbulkinsert.util.PostgreSqlUtils;
018
019import java.net.Inet4Address;
020import java.net.Inet6Address;
021import java.time.LocalDate;
022import java.time.LocalDateTime;
023import java.time.LocalTime;
024import java.time.ZonedDateTime;
025import java.util.*;
026import java.util.function.*;
027import java.util.stream.Collectors;
028
029public abstract class AbstractMapping<TEntity> {
030
031    protected boolean usePostgresQuoting;
032
033    protected final IValueHandlerProvider provider;
034
035    protected final TableDefinition table;
036
037    protected final List<ColumnDefinition<TEntity>> columns;
038
039    protected AbstractMapping(String schemaName, String tableName) {
040        this(new ValueHandlerProvider(), schemaName, tableName, false);
041    }
042
043    protected AbstractMapping(String schemaName, String tableName, boolean usePostgresQuoting) {
044        this(new ValueHandlerProvider(), schemaName, tableName, usePostgresQuoting);
045    }
046
047    protected AbstractMapping(IValueHandlerProvider provider, String schemaName, String tableName, boolean usePostgresQuoting) {
048        this.provider = provider;
049        this.table = new TableDefinition(schemaName, tableName);
050        this.usePostgresQuoting = usePostgresQuoting;
051        this.columns = new ArrayList<>();
052    }
053
054    protected void usePostgresQuoting(boolean enabled) {
055        this.usePostgresQuoting = enabled;
056    }
057
058    protected <TElementType, TCollectionType extends Collection<TElementType>> void mapCollection(String columnName, DataType dataType, Function<TEntity, TCollectionType> propertyGetter) {
059
060        final IValueHandler<TElementType> valueHandler = provider.resolve(dataType);
061        final int valueOID = ObjectIdentifier.mapFrom(dataType);
062
063        map(columnName, new CollectionValueHandler<>(valueOID, valueHandler), propertyGetter);
064    }
065
066    protected <TProperty> void map(String columnName, DataType dataType, Function<TEntity, TProperty> propertyGetter) {
067        final IValueHandler<TProperty> valueHandler = provider.resolve(dataType);
068
069        map(columnName, valueHandler, propertyGetter);
070    }
071
072    protected <TProperty> void map(String columnName, IValueHandler<TProperty> valueHandler, Function<TEntity, TProperty> propertyGetter) {
073        addColumn(columnName, (binaryWriter, entity) -> {
074            binaryWriter.write(valueHandler, propertyGetter.apply(entity));
075        });
076    }
077
078    // region Numeric
079
080    protected void mapBoolean(String columnName, Function<TEntity, Boolean> propertyGetter) {
081        map(columnName, DataType.Boolean, propertyGetter);
082    }
083
084    protected void mapBooleanPrimitive(String columnName, ToBooleanFunction<TEntity> propertyGetter) {
085        addColumn(columnName, (binaryWriter, entity) -> {
086            binaryWriter.writeBoolean(propertyGetter.applyAsBoolean(entity));
087        });
088    }
089
090    protected void mapByte(String columnName, Function<TEntity, Number> propertyGetter) {
091        map(columnName, DataType.Char, propertyGetter);
092    }
093
094    protected void mapBytePrimitive(String columnName, ToIntFunction<TEntity> propertyGetter) {
095        addColumn(columnName, (binaryWriter, entity) -> {
096            binaryWriter.writeByte(propertyGetter.applyAsInt(entity));
097        });
098    }
099
100    protected void mapShort(String columnName, Function<TEntity, Number> propertyGetter) {
101        map(columnName, DataType.Int2, propertyGetter);
102    }
103
104    protected void mapShortPrimitive(String columnName, ToIntFunction<TEntity> propertyGetter) {
105        addColumn(columnName, (binaryWriter, entity) -> {
106            binaryWriter.writeShort(propertyGetter.applyAsInt(entity));
107        });
108    }
109
110    protected void mapInteger(String columnName, Function<TEntity, Number> propertyGetter) {
111        map(columnName, DataType.Int4, propertyGetter);
112    }
113
114    protected void mapIntegerPrimitive(String columnName, ToIntFunction<TEntity> propertyGetter) {
115        addColumn(columnName, (binaryWriter, entity) -> {
116            binaryWriter.writeInt(propertyGetter.applyAsInt(entity));
117        });
118    }
119
120    protected void mapNumeric(String columnName, Function<TEntity, Number> propertyGetter) {
121        map(columnName, DataType.Numeric, propertyGetter);
122    }
123
124    protected void mapLong(String columnName, Function<TEntity, Number> propertyGetter) {
125        map(columnName, DataType.Int8, propertyGetter);
126    }
127
128    protected void mapLongPrimitive(String columnName, ToLongFunction<TEntity> propertyGetter) {
129        addColumn(columnName, (binaryWriter, entity) -> {
130            binaryWriter.writeLong(propertyGetter.applyAsLong(entity));
131        });
132    }
133
134    protected void mapFloat(String columnName, Function<TEntity, Number> propertyGetter) {
135        map(columnName, DataType.SinglePrecision, propertyGetter);
136    }
137
138    protected void mapFloatPrimitive(String columnName, ToFloatFunction<TEntity> propertyGetter) {
139        addColumn(columnName, (binaryWriter, entity) -> {
140            binaryWriter.writeFloat(propertyGetter.applyAsFloat(entity));
141        });
142    }
143
144    protected void mapDouble(String columnName, Function<TEntity, Number> propertyGetter) {
145        map(columnName, DataType.DoublePrecision, propertyGetter);
146    }
147
148    protected void mapDoublePrimitive(String columnName, ToDoubleFunction<TEntity> propertyGetter) {
149        addColumn(columnName, (binaryWriter, entity) -> {
150            binaryWriter.writeDouble(propertyGetter.applyAsDouble(entity));
151        });
152    }
153
154    // endregion
155
156    // region Network
157    protected void mapInet4Addr(String columnName, Function<TEntity, Inet4Address> propertyGetter) {
158        map(columnName, DataType.Inet4, propertyGetter);
159    }
160
161    protected void mapInet6Addr(String columnName, Function<TEntity, Inet6Address> propertyGetter) {
162        map(columnName, DataType.Inet6, propertyGetter);
163    }
164
165    protected void mapMacAddress(String columnName, Function<TEntity, MacAddress> propertyGetter) {
166        map(columnName, DataType.MacAddress, propertyGetter);
167    }
168
169    // endregion
170
171    // region Temporal
172
173    protected void mapInterval(String columnName, Function<TEntity, Interval> propertyGetter) {
174        map(columnName, DataType.Interval, propertyGetter);
175    }
176
177    protected void mapDate(String columnName, Function<TEntity, LocalDate> propertyGetter) {
178        map(columnName, DataType.Date, propertyGetter);
179    }
180
181    protected void mapTime(String columnName, Function<TEntity, LocalTime> propertyGetter) {
182        map(columnName, DataType.Time, propertyGetter);
183    }
184
185    protected void mapTimeStamp(String columnName, Function<TEntity, LocalDateTime> propertyGetter) {
186        map(columnName, DataType.Timestamp, propertyGetter);
187    }
188
189    protected void mapTimeStampTz(String columnName, Function<TEntity, ZonedDateTime> propertyGetter) {
190        map(columnName, DataType.TimestampTz, propertyGetter);
191    }
192
193    // endregion
194
195    // region Text
196
197    protected void mapText(String columnName, Function<TEntity, String> propertyGetter) {
198        map(columnName, DataType.Text, propertyGetter);
199    }
200
201    protected void mapVarChar(String columnName, Function<TEntity, String> propertyGetter) {
202        map(columnName, DataType.Text, propertyGetter);
203    }
204
205    // engregion
206
207    // region UUID
208
209    protected void mapUUID(String columnName, Function<TEntity, UUID> propertyGetter) {
210        map(columnName, DataType.Uuid, propertyGetter);
211    }
212
213    // endregion
214
215    // region JSON
216
217    protected void mapJsonb(String columnName, Function<TEntity, String> propertyGetter) {
218        map(columnName, DataType.Jsonb, propertyGetter);
219    }
220
221    // endregion
222
223    // region hstore
224
225    protected void mapHstore(String columnName, Function<TEntity, Map<String, String>> propertyGetter) {
226        map(columnName, DataType.Hstore, propertyGetter);
227    }
228
229    // endregion
230
231    // region Geo
232
233    protected void mapPoint(String columnName, Function<TEntity, Point> propertyGetter) {
234        map(columnName, DataType.Point, propertyGetter);
235    }
236
237    protected void mapBox(String columnName, Function<TEntity, Box> propertyGetter) {
238        map(columnName, DataType.Box, propertyGetter);
239    }
240
241    protected void mapPath(String columnName, Function<TEntity, Path> propertyGetter) {
242        map(columnName, DataType.Path, propertyGetter);
243    }
244
245    protected void mapPolygon(String columnName, Function<TEntity, Polygon> propertyGetter) {
246        map(columnName, DataType.Polygon, propertyGetter);
247    }
248
249    protected void mapLine(String columnName, Function<TEntity, Line> propertyGetter) {
250        map(columnName, DataType.Line, propertyGetter);
251    }
252
253    protected void mapLineSegment(String columnName, Function<TEntity, LineSegment> propertyGetter) {
254        map(columnName, DataType.LineSegment, propertyGetter);
255    }
256
257    protected void mapCircle(String columnName, Function<TEntity, Circle> propertyGetter) {
258        map(columnName, DataType.Circle, propertyGetter);
259    }
260
261    // endregion
262
263    // region Arrays
264
265    protected void mapBooleanArray(String columnName, Function<TEntity, Collection<Boolean>> propertyGetter) {
266        mapCollection(columnName, DataType.Boolean, propertyGetter);
267    }
268
269    protected void mapByteArray(String columnName, Function<TEntity, byte[]> propertyGetter) {
270        map(columnName, DataType.Bytea, propertyGetter);
271    }
272
273    protected <T extends Number> void mapShortArray(String columnName, Function<TEntity, Collection<T>> propertyGetter) {
274        mapCollection(columnName, DataType.Int2, propertyGetter);
275    }
276
277    protected <T extends Number> void mapIntegerArray(String columnName, Function<TEntity, Collection<T>> propertyGetter) {
278        mapCollection(columnName, DataType.Int4, propertyGetter);
279    }
280
281    protected <T extends Number> void mapLongArray(String columnName, Function<TEntity, Collection<T>> propertyGetter) {
282        mapCollection(columnName, DataType.Int8, propertyGetter);
283    }
284
285    protected void mapTextArray(String columnName, Function<TEntity, Collection<String>> propertyGetter) {
286        mapCollection(columnName, DataType.Text, propertyGetter);
287    }
288
289    protected void mapVarCharArray(String columnName, Function<TEntity, Collection<String>> propertyGetter) {
290        mapCollection(columnName, DataType.VarChar, propertyGetter);
291    }
292
293    protected <T extends Number> void mapFloatArray(String columnName, Function<TEntity, Collection<T>> propertyGetter) {
294        mapCollection(columnName, DataType.SinglePrecision, propertyGetter);
295    }
296
297    protected <T extends Number> void mapDoubleArray(String columnName, Function<TEntity, Collection<T>> propertyGetter) {
298        mapCollection(columnName, DataType.DoublePrecision, propertyGetter);
299    }
300
301    protected <T extends Number> void mapNumericArray(String columnName, Function<TEntity, Collection<T>> propertyGetter) {
302        mapCollection(columnName, DataType.Numeric, propertyGetter);
303    }
304
305    protected void mapUUIDArray(String columnName, Function<TEntity, Collection<UUID>> propertyGetter) {
306        mapCollection(columnName, DataType.Uuid, propertyGetter);
307    }
308
309    protected void mapInet4Array(String columnName, Function<TEntity, Collection<Inet4Address>> propertyGetter) {
310        mapCollection(columnName, DataType.Inet4, propertyGetter);
311    }
312
313    protected void mapInet6Array(String columnName, Function<TEntity, Collection<Inet6Address>> propertyGetter) {
314        mapCollection(columnName, DataType.Inet6, propertyGetter);
315    }
316
317    // endregion
318
319    // region Ranges
320
321    protected <TElementType> void mapRange(String columnName, DataType dataType, Function<TEntity, Range<TElementType>> propertyGetter) {
322        final IValueHandler<TElementType> valueHandler = provider.resolve(dataType);
323
324        map(columnName, new RangeValueHandler<>(valueHandler), propertyGetter);
325    }
326
327    protected void mapTsRange(String columnName, Function<TEntity, Range<LocalDateTime>> propertyGetter) {
328        map(columnName, DataType.TsRange, propertyGetter);
329    }
330
331    protected void mapTsTzRange(String columnName, Function<TEntity, Range<ZonedDateTime>> propertyGetter) {
332        map(columnName, DataType.TsTzRange, propertyGetter);
333    }
334
335    protected void mapInt4Range(String columnName, Function<TEntity, Range<Integer>> propertyGetter) {
336        map(columnName, DataType.Int4Range, propertyGetter);
337    }
338
339    protected void mapInt8Range(String columnName, Function<TEntity, Range<Long>> propertyGetter) {
340        map(columnName, DataType.Int8Range, propertyGetter);
341    }
342
343    protected void mapNumRange(String columnName, Function<TEntity, Range<Number>> propertyGetter) {
344        map(columnName, DataType.NumRange, propertyGetter);
345    }
346
347    protected void mapDateRange(String columnName, Function<TEntity, Range<LocalDate>> propertyGetter) {
348        map(columnName, DataType.DateRange, propertyGetter);
349    }
350
351    // endregion
352
353    private void addColumn(String columnName, BiConsumer<PgBinaryWriter, TEntity> action) {
354        columns.add(new ColumnDefinition<>(columnName, action));
355    }
356
357    public List<ColumnDefinition<TEntity>> getColumns() {
358        return columns;
359    }
360
361    public String getCopyCommand() {
362        String commaSeparatedColumns = columns.stream()
363                .map(x -> x.getColumnName())
364                .map(x -> usePostgresQuoting ? PostgreSqlUtils.quoteIdentifier(x) : x)
365                .collect(Collectors.joining(", "));
366
367        return String.format("COPY %1$s(%2$s) FROM STDIN BINARY",
368                table.GetFullyQualifiedTableName(usePostgresQuoting),
369                commaSeparatedColumns);
370    }
371}