/*
 * Decompiled with CFR 0.152.
 */
package io.fluxzero.proxy;

import io.fluxzero.common.Guarantee;
import io.fluxzero.common.MessageType;
import io.fluxzero.common.ObjectUtils;
import io.fluxzero.common.Registration;
import io.fluxzero.common.api.Data;
import io.fluxzero.common.api.DisconnectEvent;
import io.fluxzero.common.api.Metadata;
import io.fluxzero.common.api.SerializedMessage;
import io.fluxzero.common.serialization.JsonUtils;
import io.fluxzero.javaclient.Fluxzero;
import io.fluxzero.javaclient.configuration.client.Client;
import io.fluxzero.javaclient.publishing.client.GatewayClient;
import io.fluxzero.javaclient.tracking.ConsumerConfiguration;
import io.fluxzero.javaclient.tracking.IndexUtils;
import io.fluxzero.javaclient.tracking.client.DefaultTracker;
import io.fluxzero.javaclient.web.WebRequest;
import jakarta.websocket.CloseReason;
import jakarta.websocket.Endpoint;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.PongMessage;
import jakarta.websocket.Session;
import java.beans.ConstructorProperties;
import java.io.IOException;
import java.io.OutputStream;
import java.io.Writer;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Websocket {@link Endpoint} that relays websocket traffic between connected sessions and Fluxzero.
 * <p>
 * Inbound session events (open, text/binary message, pong, close) are published through the
 * {@link MessageType#WEBREQUEST} gateway. Each request carries a {@code method} metadata entry of
 * {@code WS_OPEN}, {@code WS_MESSAGE}, {@code WS_PONG} or {@code WS_CLOSE}. Its target is the
 * session's {@code _trackerId} request parameter.
 * <p>
 * {@link MessageType#WEBRESPONSE} messages are routed back to the open session named by their
 * socket session id. Depending on their {@code function} metadata, they are sent as a message, a
 * ping or a close frame. {@link DisconnectEvent}s read from the {@link MessageType#METRICS} log
 * close every session whose {@code _clientId} request parameter matches the disconnected client.
 * <p>
 * Request parameters prefixed with {@code _metadata:} are copied, without the prefix, into the
 * metadata of every request sent for that session.
 */
public class WebsocketEndpoint extends Endpoint {
    private static final Logger log = LoggerFactory.getLogger(WebsocketEndpoint.class);

    /** Request-parameter prefix marking values that are copied into request metadata. */
    static final String metadataPrefix = "_metadata:";
    /** Request parameter holding the client id associated with the session. */
    static final String clientIdKey = "_clientId";
    /** Request parameter holding the target set on every request sent for the session. */
    static final String trackerIdKey = "_trackerId";

    /** Sessions that have been opened and not yet closed, keyed by session id. */
    private final Map<String, Session> openSessions = new ConcurrentHashMap<String, Session>();
    private final Client client;
    private final GatewayClient requestGateway;
    /** Guards the one-time start of the internal trackers. */
    private final AtomicBoolean started = new AtomicBoolean();
    /** Combined registration of both internal trackers; set once they have started. */
    private volatile Registration registration;

    /**
     * Creates an endpoint that publishes requests through the web-request gateway of {@code client}.
     *
     * @param client the Fluxzero client used for publishing and tracking
     */
    public WebsocketEndpoint(Client client) {
        this.client = client;
        this.requestGateway = client.getGatewayClient(MessageType.WEBREQUEST);
    }

    /**
     * Registers the session and forwards its text, binary and pong messages as requests.
     * Starts the internal trackers on first use and finally publishes a {@code WS_OPEN} request.
     */
    @Override
    public void onOpen(Session session, EndpointConfig config) {
        ensureStarted();
        openSessions.put(session.getId(), session);
        session.addMessageHandler(byte[].class, bytes -> sendRequest(session, "WS_MESSAGE", bytes));
        session.addMessageHandler(String.class,
                text -> sendRequest(session, "WS_MESSAGE", text.getBytes(StandardCharsets.UTF_8)));
        session.addMessageHandler(PongMessage.class,
                pong -> sendRequest(session, "WS_PONG", ObjectUtils.getBytes(pong.getApplicationData())));
        sendRequest(session, "WS_OPEN", null);
    }

    /**
     * Unregisters the session and publishes a {@code WS_CLOSE} request.
     * The request payload is the numeric close code as UTF-8 text.
     */
    @Override
    public void onClose(Session session, CloseReason closeReason) {
        openSessions.remove(session.getId());
        String closeCode = String.valueOf(closeReason.getCloseCode().getCode());
        sendRequest(session, "WS_CLOSE", closeCode.getBytes(StandardCharsets.UTF_8));
    }

    /** Logs session errors; the session is left as is. */
    @Override
    public void onError(Session session, Throwable error) {
        log.warn("Error in session {}", session.getId(), error);
    }

    /**
     * Publishes a web request on behalf of {@code session}.
     * The request is sent with {@link Guarantee#SENT}.
     *
     * @param session the originating websocket session
     * @param method  value of the {@code method} metadata entry (e.g. {@code WS_MESSAGE})
     * @param payload raw request payload; may be {@code null} (as for {@code WS_OPEN})
     */
    protected void sendRequest(Session session, String method, byte[] payload) {
        SessionContext context = getContext(session);
        Metadata metadata = context.metadata().with("method", method);
        SerializedMessage request = new SerializedMessage(new Data<>(payload, null, 0, "unknown"), metadata,
                Fluxzero.generateId(), Fluxzero.currentClock().millis());
        request.setSource(client.id());
        request.setTarget(context.trackerId());
        requestGateway.append(Guarantee.SENT, request);
    }

    /**
     * Delivers web responses to the open sessions they address.
     * Messages without a socket session id, or whose session is unknown or closed, are skipped.
     * Failures are logged per message so one broken session does not affect the others.
     */
    protected void handleResultMessages(List<SerializedMessage> resultMessages) {
        for (SerializedMessage message : resultMessages) {
            String sessionId = WebRequest.getSocketSessionId(message.getMetadata());
            Session session = sessionId == null ? null : openSessions.get(sessionId);
            if (session == null || !session.isOpen()) {
                continue;
            }
            try {
                dispatchResult(message, session);
            } catch (Exception e) {
                log.warn("Failed to send websocket result to client (session {})", session.getId(), e);
            }
        }
    }

    /**
     * Executes the websocket action named by the {@code function} metadata of {@code message}.
     * A missing entry defaults to {@code message}. Unrecognized functions are ignored.
     */
    private void dispatchResult(SerializedMessage message, Session session) throws IOException {
        switch (message.getMetadata().getOrDefault("function", "message")) {
            case "message" -> sendMessage(message, session);
            case "ping" -> sendPing(message, session);
            case "close" -> sendClose(message, session);
            default -> {
                // unrecognized function: nothing to send
            }
        }
    }

    /**
     * Sends the payload of {@code m} to the session.
     * It is sent as a binary message when the data type is {@code byte[]}, otherwise as UTF-8 text.
     */
    private void sendMessage(SerializedMessage m, Session session) throws IOException {
        // Resolve the payload before opening a stream, so a bad payload does not flush an empty frame.
        byte[] payload = (byte[]) m.getData().getValue();
        if (byte[].class.getName().equals(m.getData().getType())) {
            try (OutputStream out = session.getBasicRemote().getSendStream()) {
                out.write(payload);
            }
        } else {
            try (Writer writer = session.getBasicRemote().getSendWriter()) {
                writer.write(new String(payload, StandardCharsets.UTF_8));
            }
        }
    }

    /** Sends a ping frame whose application data is the payload of {@code m}. */
    private void sendPing(SerializedMessage m, Session session) throws IOException {
        session.getBasicRemote().sendPing(ByteBuffer.wrap((byte[]) m.getData().getValue()));
    }

    /**
     * Closes the session using the close code carried by {@code m}.
     * The payload is the code as UTF-8 text; surrounding whitespace is ignored.
     */
    private void sendClose(SerializedMessage m, Session session) throws IOException {
        String code = new String((byte[]) m.getData().getValue(), StandardCharsets.UTF_8).strip();
        session.close(new CloseReason(CloseReason.CloseCodes.getCloseCode(Integer.parseInt(code)), null));
    }

    /**
     * Closes every open session whose client id matches one of the disconnected clients.
     * The message payloads are JSON-serialized {@link DisconnectEvent}s.
     * Events without a client id are ignored; otherwise they would match every session opened
     * without a {@code _clientId} parameter.
     */
    protected void handleDisconnects(List<SerializedMessage> resultMessages) {
        Set<String> clientIds = resultMessages.stream()
                .map(WebsocketEndpoint::parseDisconnectEvent)
                .map(DisconnectEvent::getClientId)
                .filter(Objects::nonNull)
                .collect(Collectors.toSet());
        if (clientIds.isEmpty()) {
            return;
        }
        for (Session session : openSessions.values()) {
            if (!clientIds.contains(getContext(session).clientId())) {
                continue;
            }
            try {
                if (session.isOpen()) {
                    session.close(new CloseReason(CloseReason.CloseCodes.GOING_AWAY, "going away"));
                }
            } catch (Exception e) {
                log.warn("Failed to close session {}", session.getId(), e);
            }
        }
    }

    /** Deserializes the JSON payload of {@code message} into a {@link DisconnectEvent}. */
    private static DisconnectEvent parseDisconnectEvent(SerializedMessage message) {
        return JsonUtils.fromJson((byte[]) message.getData().getValue(), DisconnectEvent.class);
    }

    /**
     * Returns the context derived from the session's request parameters.
     * It is computed once and cached in the session's user properties under {@code "context"}.
     */
    protected SessionContext getContext(Session session) {
        return (SessionContext) session.getUserProperties()
                .computeIfAbsent("context", key -> createContext(session));
    }

    /**
     * Builds a {@link SessionContext} from the first value of each request parameter.
     * The metadata also receives the session id under {@code sessionId}. Parameters without
     * values are skipped.
     */
    private static SessionContext createContext(Session session) {
        SessionContext.SessionContextBuilder builder = SessionContext.builder();
        Map<String, String> metadataValues = new LinkedHashMap<>();
        session.getRequestParameterMap().forEach((name, values) -> {
            if (values == null || values.isEmpty()) {
                return;
            }
            String value = values.getFirst();
            if (name.startsWith(metadataPrefix)) {
                metadataValues.put(name.substring(metadataPrefix.length()), value);
            } else if (name.equals(trackerIdKey)) {
                builder.trackerId(value);
            } else if (name.equals(clientIdKey)) {
                builder.clientId(value);
            }
        });
        builder.metadata(Metadata.of(metadataValues).with("sessionId", session.getId()));
        return builder.build();
    }

    /**
     * Starts the web-response and disconnect-event trackers exactly once.
     * If either tracker fails to start, any already-started tracker is cancelled and the endpoint
     * is reset, so a later call can try again. The failure is rethrown.
     */
    protected void ensureStarted() {
        if (!started.compareAndSet(false, true)) {
            return;
        }
        try {
            Registration responses = DefaultTracker.start(this::handleResultMessages, MessageType.WEBRESPONSE,
                    responseConsumerConfiguration(), client);
            try {
                registration = responses.merge(DefaultTracker.start(this::handleDisconnects, MessageType.METRICS,
                        disconnectConsumerConfiguration(), client));
            } catch (RuntimeException | Error e) {
                responses.cancel();
                throw e;
            }
        } catch (RuntimeException | Error e) {
            started.set(false);
            throw e;
        }
    }

    /** Consumer name shared by both internal trackers. */
    private String consumerName() {
        return String.format("%s_%s", client.name(), "$websocket-handler");
    }

    /**
     * Configuration for the web-response tracker.
     * Only messages targeted at this client are consumed, starting two seconds back.
     */
    private ConsumerConfiguration responseConsumerConfiguration() {
        return ConsumerConfiguration.builder()
                .name(consumerName())
                .ignoreSegment(true)
                .clientControlledIndex(true)
                .filterMessageTarget(true)
                .minIndex(IndexUtils.indexFromTimestamp(Fluxzero.currentTime().minusSeconds(2L)))
                .build();
    }

    /**
     * Configuration for the disconnect tracker.
     * Only {@link DisconnectEvent} metrics are consumed, starting one second back.
     */
    private ConsumerConfiguration disconnectConsumerConfiguration() {
        return ConsumerConfiguration.builder()
                .name(consumerName())
                .ignoreSegment(true)
                .clientControlledIndex(true)
                .typeFilter(Pattern.quote(DisconnectEvent.class.getName()))
                .minIndex(IndexUtils.indexFromTimestamp(Fluxzero.currentTime().minusSeconds(1L)))
                .build();
    }

    /**
     * Cancels the trackers and closes and forgets every open session with {@code GOING_AWAY}.
     * This only happens if the endpoint was started and a registration exists.
     */
    public void shutDown() {
        if (started.compareAndSet(true, false) && registration != null) {
            registration.cancel();
            openSessions.values().removeIf(session -> {
                try {
                    if (session.isOpen()) {
                        session.close(new CloseReason(CloseReason.CloseCodes.GOING_AWAY, "Redeployment"));
                    }
                } catch (Throwable e) {
                    log.warn("Failed to close session when leaving: {}", session.getId(), e);
                }
                return true;
            });
        }
    }

    /**
     * Creates an endpoint with an explicit request gateway and (optionally) a pre-existing registration.
     */
    @ConstructorProperties(value = {"client", "requestGateway", "registration"})
    public WebsocketEndpoint(Client client, GatewayClient requestGateway, Registration registration) {
        this.client = client;
        this.requestGateway = requestGateway;
        this.registration = registration;
    }

    /**
     * Per-session values derived from the websocket request parameters.
     *
     * @param metadata  metadata added to every request of the session
     * @param clientId  value of the {@code _clientId} parameter, or {@code null}
     * @param trackerId value of the {@code _trackerId} parameter, or {@code null}
     */
    protected record SessionContext(Metadata metadata, String clientId, String trackerId) {
        /** Returns a new, empty builder. */
        public static SessionContextBuilder builder() {
            return new SessionContextBuilder();
        }

        /** Mutable builder for {@link SessionContext}. */
        public static class SessionContextBuilder {
            private Metadata metadata;
            private String clientId;
            private String trackerId;

            SessionContextBuilder() {
            }

            public SessionContextBuilder metadata(Metadata metadata) {
                this.metadata = metadata;
                return this;
            }

            public SessionContextBuilder clientId(String clientId) {
                this.clientId = clientId;
                return this;
            }

            public SessionContextBuilder trackerId(String trackerId) {
                this.trackerId = trackerId;
                return this;
            }

            public SessionContext build() {
                return new SessionContext(this.metadata, this.clientId, this.trackerId);
            }

            @Override
            public String toString() {
                return "WebsocketEndpoint.SessionContext.SessionContextBuilder(metadata=" + String.valueOf(this.metadata) + ", clientId=" + this.clientId + ", trackerId=" + this.trackerId + ")";
            }
        }
    }
}

