/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.state.api;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.client.ClientUtils;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.client.program.ProgramInvocationException;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.memory.MemoryStateBackend;
import org.apache.flink.state.api.BootstrapTransformation;
import org.apache.flink.state.api.ExistingSavepoint;
import org.apache.flink.state.api.NewSavepoint;
import org.apache.flink.state.api.OperatorTransformation;
import org.apache.flink.state.api.Savepoint;
import org.apache.flink.state.api.functions.BroadcastStateBootstrapFunction;
import org.apache.flink.state.api.functions.KeyedStateBootstrapFunction;
import org.apache.flink.state.api.functions.StateBootstrapFunction;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.BroadcastProcessFunction;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.util.AbstractID;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Integration test for writing savepoints with the State Processor API.
 *
 * <p>For each parameterized state backend the test:
 * <ol>
 *   <li>bootstraps keyed state ({@link #ACCOUNT_UID}) and broadcast state ({@link #CURRENCY_UID})
 *       into a brand-new savepoint with a batch job,</li>
 *   <li>restores a streaming job from that savepoint and checks the restored state,</li>
 *   <li>loads the savepoint, removes the broadcast operator, adds operator state for
 *       {@link #MODIFY_UID}, and writes the result to a second path,</li>
 *   <li>restores a streaming job from the modified savepoint and checks the state again.</li>
 * </ol>
 */
@RunWith(value = Parameterized.class)
public class SavepointWriterITCase extends AbstractTestBase {

    private static final String ACCOUNT_UID = "accounts";

    private static final String CURRENCY_UID = "currency";

    private static final String MODIFY_UID = "numbers";

    /** Max parallelism the new savepoint is created with. */
    private static final int MAX_PARALLELISM = 128;

    /** Name of the keyed value state written by the bootstrapper and read by the streaming job. */
    private static final String TOTAL_STATE = "total";

    /** Broadcast state descriptor shared by the bootstrap function and the validating job. */
    private static final MapStateDescriptor<String, Double> CURRENCY_RATE_DESCRIPTOR =
            new MapStateDescriptor<>("currency-rate", Types.STRING, Types.DOUBLE);

    private static final Collection<Account> ACCOUNTS =
            Arrays.asList(new Account(1, 100.0), new Account(2, 100.0), new Account(3, 100.0));

    private static final Collection<CurrencyRate> CURRENCY_RATES =
            Arrays.asList(new CurrencyRate("USD", 1.0), new CurrencyRate("EUR", 1.3));

    /**
     * Elements bootstrapped into the {@link #MODIFY_UID} operator state. The same list is used
     * when validating the restored state, so writer and reader cannot drift apart.
     */
    private static final List<Integer> MODIFY_ELEMENTS = Arrays.asList(1, 2, 3);

    private final StateBackend backend;

    /**
     * Creates a test instance for one state backend.
     *
     * @param backend backend used both to write the savepoints and to run the validating jobs
     * @throws Exception if restarting the shared mini cluster fails
     */
    public SavepointWriterITCase(StateBackend backend) throws Exception {
        this.backend = backend;
        // Restart the class-wide mini cluster so every parameterized run gets a fresh cluster.
        // NOTE(review): presumably needed so the backend under test takes effect — confirm.
        miniClusterResource.after();
        miniClusterResource.before();
    }

    /**
     * Supplies the state backends under test.
     *
     * @return a heap-based backend and a RocksDB backend
     */
    @Parameterized.Parameters(name = "Savepoint Writer: {0}")
    public static Collection<StateBackend> data() {
        return Arrays.<StateBackend>asList(
                new MemoryStateBackend(),
                // The cast selects the RocksDBStateBackend(StateBackend) overload instead of the
                // more specific (AbstractStateBackend) one that MemoryStateBackend would bind to.
                new RocksDBStateBackend((StateBackend) new MemoryStateBackend()));
    }

    /** Bootstraps, validates, modifies, and re-validates a savepoint end to end. */
    @Test
    public void testStateBootstrapAndModification() throws Exception {
        String savepointPath = getTempDirPath(new AbstractID().toHexString());
        bootstrapState(savepointPath);
        validateBootstrap(savepointPath);

        String modifyPath = getTempDirPath(new AbstractID().toHexString());
        modifySavepoint(savepointPath, modifyPath);
        validateModification(modifyPath);
    }

    /** Writes a new savepoint containing keyed account state and broadcast currency state. */
    private void bootstrapState(String savepointPath) throws Exception {
        ExecutionEnvironment bEnv = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Account> accountDataSet = bEnv.fromCollection(ACCOUNTS);
        BootstrapTransformation<Account> transformation = OperatorTransformation
                .bootstrapWith(accountDataSet)
                .keyBy(acc -> acc.id)
                .transform(new AccountBootstrapper());

        DataSet<CurrencyRate> currencyDataSet = bEnv.fromCollection(CURRENCY_RATES);
        BootstrapTransformation<CurrencyRate> broadcastTransformation = OperatorTransformation
                .bootstrapWith(currencyDataSet)
                .transform(new CurrencyBootstrapFunction());

        Savepoint.create(backend, MAX_PARALLELISM)
                .withOperator(ACCOUNT_UID, transformation)
                .withOperator(CURRENCY_UID, broadcastTransformation)
                .write(savepointPath);

        bEnv.execute("Bootstrap");
    }

    /** Restores a streaming job from the bootstrapped savepoint and checks both operators' state. */
    private void validateBootstrap(String savepointPath) throws ProgramInvocationException {
        StreamExecutionEnvironment sEnv = StreamExecutionEnvironment.getExecutionEnvironment();
        sEnv.setStateBackend(backend);

        CollectSink.accountList.clear();

        sEnv.fromCollection(ACCOUNTS)
                .keyBy(acc -> acc.id)
                .flatMap(new UpdateAndGetAccount())
                .uid(ACCOUNT_UID)
                .addSink(new CollectSink());

        sEnv.fromCollection(CURRENCY_RATES)
                .connect(sEnv.fromCollection(CURRENCY_RATES).broadcast(CURRENCY_RATE_DESCRIPTOR))
                .process(new CurrencyValidationFunction())
                .uid(CURRENCY_UID)
                .addSink(new DiscardingSink<>());

        executeFromSavepoint(sEnv, savepointPath);

        Assert.assertEquals("Unexpected output", ACCOUNTS.size(), CollectSink.accountList.size());
    }

    /** Loads the savepoint, drops the broadcast operator, adds operator state, and rewrites it. */
    private void modifySavepoint(String savepointPath, String modifyPath) throws Exception {
        ExecutionEnvironment bEnv = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Integer> data = bEnv.fromCollection(MODIFY_ELEMENTS);
        BootstrapTransformation<Integer> transformation = OperatorTransformation
                .bootstrapWith(data)
                .transform(new ModifyProcessFunction());

        Savepoint.load(bEnv, savepointPath, backend)
                .removeOperator(CURRENCY_UID)
                .withOperator(MODIFY_UID, transformation)
                .write(modifyPath);

        bEnv.execute("Modifying");
    }

    /** Restores a streaming job from the modified savepoint and checks the keyed and new operator state. */
    private void validateModification(String savepointPath) throws ProgramInvocationException {
        StreamExecutionEnvironment sEnv = StreamExecutionEnvironment.getExecutionEnvironment();
        sEnv.setStateBackend(backend);

        CollectSink.accountList.clear();

        SingleOutputStreamOperator<Account> stream = sEnv.fromCollection(ACCOUNTS)
                .keyBy(acc -> acc.id)
                .flatMap(new UpdateAndGetAccount())
                .uid(ACCOUNT_UID);

        stream.addSink(new CollectSink());

        // StatefulOperator asserts on its restored operator state inside initializeState.
        stream.map(acc -> acc.id)
                .map(new StatefulOperator())
                .uid(MODIFY_UID)
                .addSink(new DiscardingSink<>());

        executeFromSavepoint(sEnv, savepointPath);

        Assert.assertEquals("Unexpected output", ACCOUNTS.size(), CollectSink.accountList.size());
    }

    /** Submits the job built on {@code sEnv}, restoring from {@code savepointPath}, and waits for it. */
    private static void executeFromSavepoint(StreamExecutionEnvironment sEnv, String savepointPath)
            throws ProgramInvocationException {
        JobGraph jobGraph = sEnv.getStreamGraph().getJobGraph();
        // Second argument is allowNonRestoredState; false so unmatched savepoint state fails the restore.
        jobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath, false));

        ClusterClient<?> client = miniClusterResource.getClusterClient();
        ClientUtils.submitJobAndWaitForResult(
                client, jobGraph, SavepointWriterITCase.class.getClassLoader());
    }

    /** Sink that records the ids of every account it receives into a static set. */
    public static class CollectSink implements SinkFunction<Account> {

        /** Ids seen by any sink instance; cleared by the test before each job. */
        static final Set<Integer> accountList = new ConcurrentSkipListSet<>();

        @Override
        public void invoke(Account value, Context context) {
            accountList.add(value.id);
        }
    }

    /** Asserts that every rate on the main input matches the rate stored in the broadcast state. */
    public static class CurrencyValidationFunction
            extends BroadcastProcessFunction<CurrencyRate, CurrencyRate, Void> {

        @Override
        public void processElement(CurrencyRate value, ReadOnlyContext ctx, Collector<Void> out)
                throws Exception {
            Double restored = ctx.getBroadcastState(CURRENCY_RATE_DESCRIPTOR).get(value.currency);
            // Fail with a clear message instead of an NPE when unboxing a missing entry.
            Assert.assertNotNull("Missing currency rate for " + value.currency, restored);
            Assert.assertEquals("Incorrect currency rate", value.rate, restored, 1.0E-4);
        }

        @Override
        public void processBroadcastElement(CurrencyRate value, Context ctx, Collector<Void> out) {
            // Intentionally empty: the broadcast state must come from the savepoint, not this input.
        }
    }

    /** Writes each currency rate into the broadcast state while bootstrapping. */
    public static class CurrencyBootstrapFunction
            extends BroadcastStateBootstrapFunction<CurrencyRate> {

        @Override
        public void processElement(CurrencyRate value, Context ctx) throws Exception {
            ctx.getBroadcastState(CURRENCY_RATE_DESCRIPTOR).put(value.currency, value.rate);
        }
    }

    /**
     * Pass-through map operator whose union list state is checked on restore.
     *
     * <p>{@code numbers} is never filled by {@link #map}, so snapshots taken by this operator are
     * empty; its purpose is to assert, in {@link #initializeState}, that the restored state is
     * exactly {@link #MODIFY_ELEMENTS}.
     */
    public static class StatefulOperator extends RichMapFunction<Integer, Integer>
            implements CheckpointedFunction {

        List<Integer> numbers;

        ListState<Integer> state;

        @Override
        public void open(Configuration parameters) {
            numbers = new ArrayList<>();
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            state.clear();
            state.addAll(numbers);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            state = context.getOperatorStateStore()
                    .getUnionListState(new ListStateDescriptor<>(MODIFY_UID, Types.INT));

            if (context.isRestored()) {
                Set<Integer> expected = new HashSet<>(MODIFY_ELEMENTS);
                for (Integer number : state.get()) {
                    // remove() returns false for duplicates and for unexpected elements alike.
                    Assert.assertTrue("Duplicate state", expected.remove(number));
                }
                Assert.assertTrue(
                        "Failed to bootstrap all state elements: " + Arrays.toString(expected.toArray()),
                        expected.isEmpty());
            }
        }

        /** Forwards the value unchanged; returning null would produce unserializable records. */
        @Override
        public Integer map(Integer value) {
            return value;
        }
    }

    /** Bootstraps the {@link #MODIFY_UID} union list state with every element it processes. */
    public static class ModifyProcessFunction extends StateBootstrapFunction<Integer> {

        List<Integer> numbers;

        ListState<Integer> state;

        @Override
        public void open(Configuration parameters) {
            numbers = new ArrayList<>();
        }

        @Override
        public void processElement(Integer value, Context ctx) {
            numbers.add(value);
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            state.clear();
            state.addAll(numbers);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            state = context.getOperatorStateStore()
                    .getUnionListState(new ListStateDescriptor<>(MODIFY_UID, Types.INT));
        }
    }

    /**
     * Adds the restored per-key total (if any) to each account's amount, stores the new total,
     * and emits the account.
     */
    public static class UpdateAndGetAccount extends RichFlatMapFunction<Account, Account> {

        ValueState<Double> state;

        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);
            state = getRuntimeContext().getState(new ValueStateDescriptor<>(TOTAL_STATE, Types.DOUBLE));
        }

        @Override
        public void flatMap(Account value, Collector<Account> out) throws Exception {
            Double current = state.value();
            if (current != null) {
                value.amount += current;
            }
            state.update(value.amount);
            out.collect(value);
        }
    }

    /** Stores each account's amount as the keyed total while bootstrapping. */
    public static class AccountBootstrapper extends KeyedStateBootstrapFunction<Integer, Account> {

        ValueState<Double> state;

        @Override
        public void open(Configuration parameters) {
            state = getRuntimeContext().getState(new ValueStateDescriptor<>(TOTAL_STATE, Types.DOUBLE));
        }

        @Override
        public void processElement(Account value, Context ctx) throws Exception {
            state.update(value.amount);
        }
    }

    /** A currency code paired with its exchange rate. */
    public static class CurrencyRate {

        public String currency;

        public Double rate;

        CurrencyRate(String currency, double rate) {
            this.currency = currency;
            this.rate = rate;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof CurrencyRate)) {
                return false;
            }
            CurrencyRate other = (CurrencyRate) obj;
            return Objects.equals(currency, other.currency) && Objects.equals(rate, other.rate);
        }

        @Override
        public int hashCode() {
            return Objects.hash(currency, rate);
        }
    }

    /** A bank account with an id, a balance, and a fixed timestamp. */
    public static class Account {

        public int id;

        public double amount;

        public long timestamp;

        Account(int id, double amount) {
            this.id = id;
            this.amount = amount;
            this.timestamp = 1000L;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Account)) {
                return false;
            }
            Account other = (Account) obj;
            // Double.compare (not ==) keeps equals consistent with the Double-based hashCode
            // for -0.0/0.0 and NaN.
            return id == other.id && Double.compare(amount, other.amount) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(id, amount);
        }
    }
}

