package pl.codewise.commons.aws.cqrs.discovery;

import com.amazonaws.services.autoscaling.AmazonAutoScaling;
import com.amazonaws.services.autoscaling.model.AutoScalingInstanceDetails;
import com.amazonaws.services.autoscaling.model.DescribeAutoScalingInstancesRequest;
import com.amazonaws.services.autoscaling.model.DescribeAutoScalingInstancesResult;
import com.amazonaws.services.ec2.AmazonEC2;
import com.amazonaws.services.ec2.model.DescribeInstancesRequest;
import com.amazonaws.services.ec2.model.DescribeSpotInstanceRequestsRequest;
import com.amazonaws.services.ec2.model.EbsInstanceBlockDevice;
import com.amazonaws.services.ec2.model.GroupIdentifier;
import com.amazonaws.services.ec2.model.IamInstanceProfile;
import com.amazonaws.services.ec2.model.Instance;
import com.amazonaws.services.ec2.model.InstanceBlockDeviceMapping;
import com.amazonaws.services.ec2.model.InstanceNetworkInterface;
import com.amazonaws.services.ec2.model.InstanceNetworkInterfaceAssociation;
import com.amazonaws.services.ec2.model.InstanceNetworkInterfaceAttachment;
import com.amazonaws.services.ec2.model.InstancePrivateIpAddress;
import com.amazonaws.services.ec2.model.InstanceStateName;
import com.amazonaws.services.ec2.model.SpotInstanceRequest;
import com.amazonaws.services.ec2.model.Tag;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import pl.codewise.commons.aws.cqrs.model.AwsInstance;
import pl.codewise.commons.aws.cqrs.model.AwsNetworkInterface;
import pl.codewise.commons.aws.cqrs.model.AwsNetworkInterfaceAttachment;
import pl.codewise.commons.aws.cqrs.model.AwsPrivateIpAddress;
import pl.codewise.commons.aws.cqrs.model.AwsPrivateIpAddressAssociation;
import pl.codewise.commons.aws.cqrs.model.ec2.AwsInstanceBlockDeviceMapping;
import pl.codewise.commons.aws.cqrs.model.ec2.AwsInstanceEbs;
import pl.codewise.commons.aws.cqrs.model.ec2.AwsResourceTag;
import pl.codewise.commons.aws.cqrs.model.ec2.autoscaling.AwsAutoScalingDetails;
import pl.codewise.commons.aws.cqrs.model.ec2.sg.AwsSecurityGroup;
import pl.codewise.commons.aws.cqrs.model.ec2.spot.AwsSpotRequestDetails;
import pl.codewise.commons.aws.cqrs.utils.Awaitilities;

import javax.annotation.Nullable;
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Stream;

import static com.google.common.collect.Lists.partition;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;

public class Ec2Discovery {

    private static final Logger log = LoggerFactory.getLogger(Ec2Discovery.class);

    /**
     * Chunk size used when ID lists are split into several describe calls
     * (auto-scaling instance details and spot instance requests).
     */
    private static final int MAX_NUMBER_OF_INSTANCE_IDS_IN_DESCRIBE_REQUEST = 50;

    private final String region;
    private final AmazonEC2 amazonEC2;
    private final AmazonAutoScaling amazonAutoScaling;
    private final Awaitilities awaitilities;
    private final long defaultPollInterval;

    /**
     * @param region              region name copied into every {@link AwsInstance} built by this class
     * @param amazonEC2           client used for instance and spot-request lookups
     * @param amazonAutoScaling   client used for auto-scaling instance lookups
     * @param awaitilities        polling helper used by the {@code waitFor...} methods
     * @param defaultPollInterval poll interval handed to {@code awaitilities}; presumably milliseconds,
     *                            since it is passed next to {@code Duration.toMillis()} — TODO confirm
     */
    public Ec2Discovery(String region, AmazonEC2 amazonEC2, AmazonAutoScaling amazonAutoScaling,
            Awaitilities awaitilities, long defaultPollInterval) {
        this.region = region;
        this.amazonEC2 = amazonEC2;
        this.amazonAutoScaling = amazonAutoScaling;
        this.awaitilities = awaitilities;
        this.defaultPollInterval = defaultPollInterval;
    }

    /**
     * Describes all instances returned by an unfiltered DescribeInstances call.
     *
     * @return every described instance, enriched with auto-scaling and spot details where available
     */
    public List<AwsInstance> getAllInstances() {
        return describeInstances(null);
    }

    /**
     * Describes a single instance.
     *
     * @param instanceId EC2 instance ID
     * @return the first described instance, or {@code null} if the describe call returned none.
     *         NOTE(review): EC2 may reject unknown IDs with an exception rather than an empty
     *         result — confirm which one callers should expect.
     */
    @Nullable
    public AwsInstance getInstance(String instanceId) {
        return describeInstances(singletonList(instanceId)).stream().findFirst().orElse(null);
    }

    /**
     * Describes the given instances.
     *
     * @param instanceIds EC2 instance IDs; an empty list short-circuits without any AWS call
     * @return the described instances (possibly fewer than requested)
     */
    public List<AwsInstance> getInstances(List<String> instanceIds) {
        if (instanceIds.isEmpty()) {
            return emptyList();
        }

        return describeInstances(instanceIds);
    }

    /**
     * Blocks until the instance reports the {@code terminated} state.
     *
     * @param instanceId                     EC2 instance ID
     * @param instanceToBeTerminatedWaitTime maximum wait handed to {@code awaitilities}
     */
    public void waitForInstanceTerminated(String instanceId, long instanceToBeTerminatedWaitTime) {
        waitForInstanceInState(instanceId, instanceToBeTerminatedWaitTime, InstanceStateName.Terminated);
    }

    /**
     * Returns the value of the {@code Name} tag of the given instance.
     *
     * @param instanceId EC2 instance ID
     * @return the first {@code Name} tag value found, or an empty string if there is none
     */
    public String retrieveInstanceNameTag(String instanceId) {
        log.info("About to retrieve Name tag of instance {}", instanceId);
        return instancesStream(singletonList(instanceId))
                .flatMap(i -> i.getTags().stream())
                // Constant-first comparison so a tag with a null key cannot throw.
                .filter(tag -> "Name".equals(tag.getKey()))
                .findFirst()
                .map(Tag::getValue)
                .orElse("");
    }

    /**
     * Polls until the first described instance for {@code instanceId} is in {@code expectedState}.
     */
    private void waitForInstanceInState(String instanceId, long waitTime, InstanceStateName expectedState) {
        log.info("Waiting for instance {} to enter in {} state", instanceId, expectedState);
        awaitilities.awaitTillActionSucceed(waitTime, defaultPollInterval,
                format("instance %s to be in %s state", instanceId, expectedState.toString()),
                () -> instancesStream(singletonList(instanceId))
                        .findFirst()
                        .map(i -> InstanceStateName.fromValue(i.getState().getName()))
                        .filter(stateName -> stateName == expectedState)
                        .isPresent());
        log.info("Instance {} is in state: {}", instanceId, expectedState);
    }

    /**
     * Polls until the described instance satisfies {@code statePredicate}.
     *
     * @param instanceId     EC2 instance ID
     * @param statePredicate condition the described instance must meet
     * @param description    human-readable description passed to {@code awaitilities}
     * @param maxWait        maximum wait; values beyond {@code Integer.MAX_VALUE} ms are clamped
     *                       rather than wrapping to a negative number
     * @return the instance that satisfied the predicate, as produced by {@code awaitForValue}
     */
    public AwsInstance waitForInstanceInState(String instanceId, Predicate<AwsInstance> statePredicate,
            String description, Duration maxWait) {
        return awaitilities.awaitForValue(
                saturatedToInt(maxWait.toMillis()),
                saturatedToInt(defaultPollInterval),
                () -> describeInstances(singletonList(instanceId)).stream()
                        .findFirst()
                        .filter(statePredicate)
                        .orElse(null),
                description
        );
    }

    /** Narrows a long to int, clamping at the int range instead of silently wrapping around. */
    private static int saturatedToInt(long value) {
        return (int) Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, value));
    }

    /**
     * Describes instances and joins them with their auto-scaling and spot-request details.
     *
     * @param instanceIds IDs to describe, or {@code null} to describe without an ID filter
     */
    private List<AwsInstance> describeInstances(List<String> instanceIds) {
        // Merge function: keep the first entry rather than failing the whole call on a duplicate ID.
        Map<String, AutoScalingInstanceDetails> autoScalingDetailsById =
                findAutoScalingInstanceDetails(instanceIds).stream()
                        .collect(toMap(AutoScalingInstanceDetails::getInstanceId, identity(),
                                (first, duplicate) -> first));

        // A mismatch is only logged: instances without auto-scaling details simply get none attached.
        if (instanceIds != null && autoScalingDetailsById.size() != instanceIds.size()) {
            log.debug("Did not match all instances to auto-scaling details. Expected: {}, actual: {}",
                    instanceIds, autoScalingDetailsById.keySet());
        }

        List<Instance> instances = instancesStream(instanceIds).collect(toList());

        if (instanceIds != null && instances.size() != instanceIds.size()) {
            log.debug("Did not find all requested instances. Expected: {}, actual: {}",
                    instanceIds, instances.stream().map(Instance::getInstanceId).collect(toList()));
        }

        // Several instances may reference the same spot request; query each request ID once.
        List<String> spotRequestIds = instances.stream()
                .filter(instance -> "spot".equals(instance.getInstanceLifecycle()))
                .map(Instance::getSpotInstanceRequestId)
                .filter(requestId -> requestId != null)
                .distinct()
                .collect(toList());

        // Requests are keyed by their current instance ID, which may repeat (e.g. null for several
        // requests); keep the first instead of throwing IllegalStateException.
        Map<String, SpotInstanceRequest> spotRequestsDetailsById =
                findSpotInstanceRequestDetails(spotRequestIds).stream()
                        .collect(toMap(SpotInstanceRequest::getInstanceId, identity(),
                                (first, duplicate) -> first));

        return instances.stream().map(i -> toAwsInstance(i,
                autoScalingDetailsById.get(i.getInstanceId()),
                spotRequestsDetailsById.get(i.getInstanceId())))
                .collect(toList());
    }

    /**
     * Issues one DescribeInstances call and flattens reservations into instances.
     * NOTE(review): the response's NextToken is not followed here.
     *
     * @param instanceIds IDs to describe, or {@code null} for no ID filter
     */
    private Stream<Instance> instancesStream(List<String> instanceIds) {
        return amazonEC2.describeInstances(new DescribeInstancesRequest().withInstanceIds(instanceIds))
                .getReservations().stream()
                .flatMap(r -> r.getInstances().stream());
    }

    /**
     * Converts an SDK instance into the module's {@link AwsInstance} model.
     *
     * @param autoScalingInstanceDetails auto-scaling details, or {@code null} if none were found
     * @param spotInstanceRequest        spot request details, or {@code null} if none were found
     */
    private AwsInstance toAwsInstance(Instance instance,
            AutoScalingInstanceDetails autoScalingInstanceDetails,
            SpotInstanceRequest spotInstanceRequest) {
        String state = instance.getState() != null ? instance.getState().getName() : null;
        IamInstanceProfile iamInstanceProfile = instance.getIamInstanceProfile();

        AwsInstance.Builder builder = new AwsInstance.Builder()
                .withInstanceId(instance.getInstanceId())
                .withPublicDnsName(instance.getPublicDnsName())
                .withPrivateIpAddress(instance.getPrivateIpAddress())
                .withPublicIpAddress(instance.getPublicIpAddress())
                .withIamInstanceProfileArn(iamInstanceProfile != null ? iamInstanceProfile.getArn() : null)
                .withImageId(instance.getImageId())
                .withRegion(region)
                .withState(state)
                .withInstanceType(instance.getInstanceType())
                .withKeyName(instance.getKeyName())
                .withNetworkInterfaces(toAwsNetworkInterfaces(instance.getNetworkInterfaces()))
                .withSecurityGroups(toAwsSecurityGroups(instance.getSecurityGroups()))
                .withSubnetId(instance.getSubnetId())
                .withLaunchTime(instance.getLaunchTime())
                .withLifecycle(instance.getInstanceLifecycle())
                .withTags(toAwsResourceTags(instance.getTags()))
                .withBlockDeviceMapping(toBlockDeviceMappings(instance.getBlockDeviceMappings()));

        if (autoScalingInstanceDetails != null) {
            builder.withAutoScalingDetails(new AwsAutoScalingDetails.Builder()
                    .withLaunchConfigurationName(autoScalingInstanceDetails.getLaunchConfigurationName())
                    .withAvailabilityZone(autoScalingInstanceDetails.getAvailabilityZone())
                    .withLifecycleState(autoScalingInstanceDetails.getLifecycleState())
                    .build());
        }

        if (spotInstanceRequest != null) {
            // Guard the status so a request without one cannot abort the whole describe with an NPE.
            boolean hasStatus = spotInstanceRequest.getStatus() != null;
            builder.withSpotRequestDetails(new AwsSpotRequestDetails.Builder()
                    .withRequestId(spotInstanceRequest.getSpotInstanceRequestId())
                    .withStatus(hasStatus ? spotInstanceRequest.getStatus().getCode() : null)
                    .withStatusUpdateTime(hasStatus ? spotInstanceRequest.getStatus().getUpdateTime() : null)
                    .withProductDescription(spotInstanceRequest.getProductDescription())
                    .build());
        }

        return builder.build();
    }

    /** Maps security-group identifiers to the module model; a null list maps to an empty list. */
    private List<AwsSecurityGroup> toAwsSecurityGroups(List<GroupIdentifier> securityGroups) {
        return securityGroups == null ? emptyList() :
                securityGroups.stream()
                        .map(sg -> new AwsSecurityGroup(sg.getGroupId()).withGroupName(sg.getGroupName()))
                        .collect(toList());
    }

    /** Maps SDK tags to {@link AwsResourceTag}s (key and value). */
    private List<AwsResourceTag> toAwsResourceTags(List<Tag> tags) {
        return tags.stream()
                .map(tag -> AwsResourceTag.create(tag.getKey(), tag.getValue()))
                .collect(toList());
    }

    /** Maps every instance network interface to the module model. */
    private List<AwsNetworkInterface> toAwsNetworkInterfaces(List<InstanceNetworkInterface> networkInterfaces) {
        return networkInterfaces.stream()
                .map(this::toAwsNetworkInterface)
                .collect(toList());
    }

    /** Copies the interface ID, its private IP addresses and its attachment. */
    private AwsNetworkInterface toAwsNetworkInterface(InstanceNetworkInterface networkInterface) {
        return new AwsNetworkInterface.Builder()
                .withNetworkInterfaceId(networkInterface.getNetworkInterfaceId())
                .withPrivateIpAddresses(toPrivateIpAddresses(networkInterface.getPrivateIpAddresses()))
                .withAttachment(toAttachment(networkInterface.getAttachment()))
                .build();
    }

    /** Copies the device index of an attachment; returns {@code null} when there is no attachment. */
    private AwsNetworkInterfaceAttachment toAttachment(InstanceNetworkInterfaceAttachment attachment) {
        return attachment == null ? null
                : new AwsNetworkInterfaceAttachment.Builder()
                .withDeviceIndex(attachment.getDeviceIndex())
                .build();
    }

    /** Maps every private IP address of an interface to the module model. */
    private List<AwsPrivateIpAddress> toPrivateIpAddresses(List<InstancePrivateIpAddress> privateIpAddresses) {
        return privateIpAddresses.stream()
                .map(this::toPrivateIpAddress)
                .collect(toList());
    }

    /** Copies the private IP address and its (optional) association. */
    private AwsPrivateIpAddress toPrivateIpAddress(InstancePrivateIpAddress privateIpAddress) {
        return new AwsPrivateIpAddress.Builder()
                .withPrivateIpAddress(privateIpAddress.getPrivateIpAddress())
                .withAssociation(toAssociation(privateIpAddress.getAssociation()))
                .build();
    }

    /** Copies the IP owner ID of an association; returns {@code null} when there is no association. */
    private AwsPrivateIpAddressAssociation toAssociation(InstanceNetworkInterfaceAssociation association) {
        return association == null ? null
                : new AwsPrivateIpAddressAssociation.Builder()
                .withOwnerId(association.getIpOwnerId())
                .build();
    }

    /**
     * Fetches auto-scaling details: with {@code null} IDs a single unfiltered (paginated) query is made,
     * otherwise the IDs are split into chunks of {@link #MAX_NUMBER_OF_INSTANCE_IDS_IN_DESCRIBE_REQUEST}.
     */
    private List<AutoScalingInstanceDetails> findAutoScalingInstanceDetails(List<String> instanceIds) {
        if (instanceIds == null) {
            return findAutoScalingInstanceDetailsForChunk(null);
        } else {
            return findInChunks(partition(instanceIds, MAX_NUMBER_OF_INSTANCE_IDS_IN_DESCRIBE_REQUEST),
                    this::findAutoScalingInstanceDetailsForChunk);
        }
    }

    /**
     * Fetches spot requests in chunks of {@link #MAX_NUMBER_OF_INSTANCE_IDS_IN_DESCRIBE_REQUEST};
     * {@code null} or empty input yields an empty list without any AWS call.
     */
    private List<SpotInstanceRequest> findSpotInstanceRequestDetails(List<String> spotRequestIds) {
        if (spotRequestIds == null) {
            return Collections.emptyList();
        }
        return findInChunks(partition(spotRequestIds, MAX_NUMBER_OF_INSTANCE_IDS_IN_DESCRIBE_REQUEST),
                this::findSpotInstanceRequestsForChunk);
    }

    /** Applies {@code findAction} to each chunk and concatenates the results in chunk order. */
    private <T> List<T> findInChunks(List<List<String>> chunks, Function<List<String>, List<T>> findAction) {
        return chunks.stream().map(findAction).flatMap(Collection::stream).collect(toList());
    }

    /** Describes auto-scaling instances for one chunk, following NextToken until it is {@code null}. */
    private List<AutoScalingInstanceDetails> findAutoScalingInstanceDetailsForChunk(List<String> instanceIds) {
        List<AutoScalingInstanceDetails> instances = Lists.newArrayList();
        String nextToken = null;
        do {
            DescribeAutoScalingInstancesRequest request = new DescribeAutoScalingInstancesRequest()
                    .withNextToken(nextToken)
                    .withInstanceIds(instanceIds);
            DescribeAutoScalingInstancesResult result = amazonAutoScaling.describeAutoScalingInstances(request);
            nextToken = result.getNextToken();
            instances.addAll(result.getAutoScalingInstances());
        } while (nextToken != null);
        return instances;
    }

    /** Describes the given spot instance requests with a single call. */
    private List<SpotInstanceRequest> findSpotInstanceRequestsForChunk(List<String> requestIds) {
        return amazonEC2.describeSpotInstanceRequests(new DescribeSpotInstanceRequestsRequest()
                .withSpotInstanceRequestIds(requestIds)).getSpotInstanceRequests();
    }

    /** Maps every block device mapping of an instance to the module model. */
    private List<AwsInstanceBlockDeviceMapping> toBlockDeviceMappings(List<InstanceBlockDeviceMapping> mappings) {
        return mappings.stream().map(this::toBlockDeviceMapping).collect(toList());
    }

    /** Copies the device name and EBS details of a mapping. */
    private AwsInstanceBlockDeviceMapping toBlockDeviceMapping(InstanceBlockDeviceMapping mapping) {
        return AwsInstanceBlockDeviceMapping.create(
                mapping.getDeviceName(),
                toEbs(mapping.getEbs())
        );
    }

    /**
     * Copies attach time, delete-on-termination flag, status and volume ID.
     * NOTE(review): assumes {@code ebs} is non-null — a mapping without EBS data would NPE here.
     */
    private AwsInstanceEbs toEbs(EbsInstanceBlockDevice ebs) {
        return AwsInstanceEbs.create(
                ebs.getAttachTime(),
                ebs.getDeleteOnTermination(),
                ebs.getStatus(),
                ebs.getVolumeId()
        );
    }
}
