package org.apache.flink.table.runtime.operators.aggregate;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.dataformat.JoinedRow;
import org.apache.flink.table.dataformat.util.BaseRowUtil;
import org.apache.flink.table.runtime.context.ExecutionContext;
import org.apache.flink.table.runtime.dataview.PerKeyStateDataViewStore;
import org.apache.flink.table.runtime.generated.AggsHandleFunction;
import org.apache.flink.table.runtime.generated.GeneratedAggsHandleFunction;
import org.apache.flink.table.runtime.generated.GeneratedRecordEqualiser;
import org.apache.flink.table.runtime.generated.RecordEqualiser;
import org.apache.flink.table.runtime.operators.bundle.MapBundleFunction;
import org.apache.flink.table.runtime.types.InternalSerializers;
import org.apache.flink.table.runtime.typeutils.BaseRowTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Collector;

/**
 * Mini-batch group aggregate function.
 *
 * <p>{@link #addInput} buffers a deep copy of every input row under its key. {@link #finishBundle}
 * then, per key, folds the buffered rows into the accumulators kept in keyed value state and emits
 * the changed aggregate result (key joined with aggregate value), optionally preceded by a
 * retraction of the previous result.
 */
public class MiniBatchGroupAggFunction extends MapBundleFunction<BaseRow, List<BaseRow>, BaseRow, BaseRow> {
    private static final long serialVersionUID = 7455939331036508477L;

    /** Generated aggregate-handler code; instantiated in {@link #open}. */
    private final GeneratedAggsHandleFunction genAggsHandler;
    /** Generated equaliser code used to compare previous and new aggregate results. */
    private final GeneratedRecordEqualiser genRecordEqualiser;
    /** Logical types of the accumulator fields; used to build the accumulator state's type info. */
    private final LogicalType[] accTypes;
    /** Row type of the input; used to create the serializer that deep-copies buffered rows. */
    private final RowType inputType;
    /** Decides whether a key's accumulators no longer account for any record. */
    private final RecordCounter recordCounter;
    /** Whether to emit a retraction of the previous result before an updated result. */
    private final boolean generateRetraction;

    /** Serializer used to deep-copy incoming rows before buffering them. */
    private transient TypeSerializer<BaseRow> inputRowSerializer;
    /** Output row (key joined with aggregate value), reused for every emitted record. */
    private transient JoinedRow resultRow = new JoinedRow();
    /** Instantiated aggregate handler; null until {@link #open} runs. */
    private transient AggsHandleFunction function = null;
    /** Instantiated equaliser; null until {@link #open} runs. */
    private transient RecordEqualiser equaliser = null;
    /** Per-key accumulator state; null until {@link #open} runs. */
    private transient ValueState<BaseRow> accState = null;

    /**
     * Creates the function.
     *
     * @param genAggsHandler generated aggregate-handler code
     * @param genRecordEqualiser generated equaliser code used to compare aggregate results
     * @param accTypes logical types of the accumulator fields
     * @param inputType row type of the input records
     * @param indexOfCountStar index handed to {@link RecordCounter#of(int)} (presumably the position
     *     of a COUNT(*) accumulator, if any — see RecordCounter)
     * @param generateRetraction whether to emit a retraction of the previous result on update
     */
    public MiniBatchGroupAggFunction(
            GeneratedAggsHandleFunction genAggsHandler,
            GeneratedRecordEqualiser genRecordEqualiser,
            LogicalType[] accTypes,
            RowType inputType,
            int indexOfCountStar,
            boolean generateRetraction) {
        this.genAggsHandler = genAggsHandler;
        this.genRecordEqualiser = genRecordEqualiser;
        this.recordCounter = RecordCounter.of(indexOfCountStar);
        this.accTypes = accTypes;
        this.inputType = inputType;
        this.generateRetraction = generateRetraction;
    }

    /**
     * Instantiates the generated handler and equaliser, registers the accumulator state and
     * creates the input-row serializer.
     *
     * @param executionContext context providing the runtime context and user-code classloader
     * @throws Exception if code instantiation or handler opening fails
     */
    @Override // org.apache.flink.table.runtime.operators.bundle.MapBundleFunction
    public void open(ExecutionContext executionContext) throws Exception {
        super.open(executionContext);
        ClassLoader userCodeClassLoader = executionContext.getRuntimeContext().getUserCodeClassLoader();

        this.function = this.genAggsHandler.newInstance(userCodeClassLoader);
        this.function.open(new PerKeyStateDataViewStore(executionContext.getRuntimeContext()));
        this.equaliser = this.genRecordEqualiser.newInstance(userCodeClassLoader);

        ValueStateDescriptor<BaseRow> accStateDescriptor =
                new ValueStateDescriptor<>("accState", new BaseRowTypeInfo(this.accTypes));
        this.accState = executionContext.getRuntimeContext().getState(accStateDescriptor);

        this.inputRowSerializer =
                InternalSerializers.create(this.inputType, executionContext.getRuntimeContext().getExecutionConfig());
        // Transient field initializers do not survive deserialization, so (re)create it here.
        this.resultRow = new JoinedRow();
    }

    /**
     * Appends a deep copy of {@code input} to the rows buffered for its key.
     *
     * @param value rows already buffered for the key, or {@code null} if none yet
     * @param input the incoming row
     * @return the (possibly newly created) buffer containing {@code input}'s copy
     * @throws Exception if copying the row fails
     */
    @Override // org.apache.flink.table.runtime.operators.bundle.MapBundleFunction
    public List<BaseRow> addInput(@Nullable List<BaseRow> value, BaseRow input) throws Exception {
        List<BaseRow> bufferedRows = value;
        if (bufferedRows == null) {
            bufferedRows = new ArrayList<>();
        }
        // Copy so the buffered row is independent of the incoming object, which the caller may reuse.
        bufferedRows.add(this.inputRowSerializer.copy(input));
        return bufferedRows;
    }

    /**
     * Applies every key's buffered rows to its accumulators and emits the resulting changes.
     *
     * <p>Per key:
     * <ul>
     *   <li>If the accumulators report a zero record count afterwards, the previous result is
     *       retracted (unless the key had no prior state) and the key's state is cleared.</li>
     *   <li>Otherwise the state is updated. For a key without prior state, the new result is emitted.
     *       For an existing key, output is produced only if the result changed: an optional
     *       retraction of the old result, then the new result.</li>
     * </ul>
     *
     * @param buffer buffered input rows grouped by key
     * @param collector receives the emitted result rows
     * @throws Exception if state access or aggregation fails
     */
    @Override // org.apache.flink.table.runtime.operators.bundle.MapBundleFunction
    public void finishBundle(Map<BaseRow, List<BaseRow>> buffer, Collector<BaseRow> collector) throws Exception {
        for (Map.Entry<BaseRow, List<BaseRow>> entry : buffer.entrySet()) {
            BaseRow currentKey = entry.getKey();
            List<BaseRow> inputRows = entry.getValue();

            // Scope the keyed state accesses below to the current key.
            this.ctx.setCurrentKey(currentKey);
            BaseRow acc = this.accState.value();
            boolean firstRow = false;
            if (acc == null) {
                acc = this.function.createAccumulators();
                firstRow = true;
            }
            this.function.setAccumulators(acc);

            // Snapshot the result before this bundle so a change can be detected afterwards.
            BaseRow prevAggValue = this.function.getValue();
            for (BaseRow input : inputRows) {
                if (BaseRowUtil.isAccumulateMsg(input)) {
                    this.function.accumulate(input);
                } else {
                    this.function.retract(input);
                }
            }
            BaseRow newAggValue = this.function.getValue();
            acc = this.function.getAccumulators();

            if (this.recordCounter.recordCountIsZero(acc)) {
                // No records remain for this key: retract the previous result and drop all state.
                if (!firstRow) {
                    emit(collector, currentKey, prevAggValue, BaseRowUtil.RETRACT_MSG);
                }
                this.accState.clear();
                this.function.cleanup();
            } else {
                this.accState.update(acc);
                if (firstRow) {
                    emit(collector, currentKey, newAggValue, BaseRowUtil.ACCUMULATE_MSG);
                } else if (!this.equaliser.equalsWithoutHeader(prevAggValue, newAggValue)) {
                    if (this.generateRetraction) {
                        emit(collector, currentKey, prevAggValue, BaseRowUtil.RETRACT_MSG);
                    }
                    emit(collector, currentKey, newAggValue, BaseRowUtil.ACCUMULATE_MSG);
                }
                // Otherwise the result is unchanged and nothing is emitted.
            }
        }
    }

    /** Emits {@code key} joined with {@code aggValue} through the reused output row, tagged with {@code header}. */
    private void emit(Collector<BaseRow> collector, BaseRow key, BaseRow aggValue, byte header) {
        this.resultRow.replace(key, aggValue).setHeader(header);
        collector.collect(this.resultRow);
    }

    /**
     * Closes the aggregate handler if {@link #open} created one.
     *
     * @throws Exception if closing the handler fails
     */
    @Override // org.apache.flink.table.runtime.operators.bundle.MapBundleFunction
    public void close() throws Exception {
        if (this.function != null) {
            this.function.close();
        }
    }
}
