package dev.langchain4j.model.vertexai;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.data.image.Image;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.vertexai.spi.VertexAiImageModelBuilderFactory;
import dev.langchain4j.spi.ServiceHelper;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * {@link ImageModel} that generates and edits images through the Vertex AI prediction API
 * ({@link PredictionServiceClient}).
 *
 * <p>Every {@code generate}/{@code edit} call opens its own prediction client. It sends a single
 * prediction instance (the prompt, plus an optional base image and mask) together with the
 * configured sampling parameters. It returns one {@link Image} per prediction in the response.
 *
 * <p>When persistence is enabled, each returned image is also base64-decoded and written as a
 * {@code .png} file into a local directory. The file's URI is attached to the {@link Image}.
 */
public class VertexAiImageModel implements ImageModel {

    /** Largest seed value accepted by the constructor (2^32 - 1); the smallest is 0. */
    private static final long MAX_SEED = 4294967295L;

    /** Retry count used when no {@code maxRetries} value is supplied. */
    private static final int DEFAULT_MAX_RETRIES = 3;

    private final Long seed;
    private final String endpoint;
    private final EndpointName endpointName;
    private final String language;
    private final Integer guidanceScale;
    private final String negativePrompt;
    private final ImageStyle sampleImageStyle;
    private final Integer sampleImageSize;
    private final int maxRetries;
    private final Boolean withPersisting;
    /** Directory persisted images are written to; {@code null} when persistence is disabled. */
    private final Path tempDirectory;

    /**
     * Fluent builder for {@link VertexAiImageModel}.
     *
     * <p>Obtain instances through {@link VertexAiImageModel#builder()}, which honours any
     * {@link VertexAiImageModelBuilderFactory} registered as a service.
     */
    public static class Builder {
        private String endpoint;
        private String project;
        private String location;
        private String publisher;
        private String modelName;
        private Long seed;
        private String language;
        private String negativePrompt;
        private ImageStyle sampleImageStyle;
        private Integer sampleImageSize;
        private Integer maxRetries;
        private Integer guidanceScale;
        private Boolean withPersisting;
        private Path persistTo;

        public Builder endpoint(String endpoint) {
            this.endpoint = endpoint;
            return this;
        }

        public Builder project(String project) {
            this.project = project;
            return this;
        }

        public Builder location(String location) {
            this.location = location;
            return this;
        }

        public Builder publisher(String publisher) {
            this.publisher = publisher;
            return this;
        }

        public Builder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /** Sets the sampling seed; it must lie in {@code [0, 2^32 - 1]}. */
        public Builder seed(Long seed) {
            this.seed = seed;
            return this;
        }

        public Builder language(String language) {
            this.language = language;
            return this;
        }

        public Builder guidanceScale(Integer guidanceScale) {
            this.guidanceScale = guidanceScale;
            return this;
        }

        public Builder negativePrompt(String negativePrompt) {
            this.negativePrompt = negativePrompt;
            return this;
        }

        public Builder sampleImageStyle(ImageStyle sampleImageStyle) {
            this.sampleImageStyle = sampleImageStyle;
            return this;
        }

        /** Sets the sample image size; note that this also switches requests to {@code "upscale"} mode. */
        public Builder sampleImageSize(Integer sampleImageSize) {
            this.sampleImageSize = sampleImageSize;
            return this;
        }

        /** Sets the retry count for prediction calls (defaults to 3 when left unset). */
        public Builder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        /**
         * Enables writing generated images to disk.
         *
         * <p>Images go into a fresh temporary directory unless a directory is also set through
         * {@link #persistTo(Path)}.
         */
        public Builder withPersisting() {
            this.withPersisting = Boolean.TRUE;
            return this;
        }

        /** Enables persistence and writes images into {@code persistTo}, creating it if needed. */
        public Builder persistTo(Path persistTo) {
            this.persistTo = persistTo;
            return withPersisting();
        }

        public VertexAiImageModel build() {
            return new VertexAiImageModel(
                    endpoint,
                    project,
                    location,
                    publisher,
                    modelName,
                    seed,
                    language,
                    guidanceScale,
                    negativePrompt,
                    sampleImageStyle,
                    sampleImageSize,
                    maxRetries,
                    withPersisting,
                    persistTo);
        }
    }

    /** Image styles; each constant's {@code name()} is sent verbatim as the {@code sampleImageStyle} parameter. */
    public enum ImageStyle {
        photograph,
        digital_art,
        landscape,
        sketch,
        watercolor,
        cyberpunk,
        pop_art
    }

    /**
     * Creates a model bound to one Vertex AI publisher model endpoint.
     *
     * @param endpoint         prediction service endpoint; must not be blank
     * @param project          project id; must not be blank
     * @param location         location; must not be blank
     * @param publisher        model publisher; must not be blank
     * @param modelName        model name; must not be blank
     * @param seed             optional seed in {@code [0, 2^32 - 1]}
     * @param language         optional {@code language} request parameter
     * @param guidanceScale    optional {@code guidanceScale} request parameter
     * @param negativePrompt   optional {@code negativePrompt} request parameter
     * @param sampleImageStyle optional image style
     * @param sampleImageSize  optional sample image size (also enables {@code "upscale"} mode)
     * @param maxRetries       retry count for prediction calls; {@code null} means 3
     * @param withPersisting   {@code TRUE} to write generated images to disk
     * @param persistTo        target directory when persisting; {@code null} means a new temp directory
     * @throws RuntimeException if persistence is enabled and the directory cannot be created
     */
    public VertexAiImageModel(String endpoint,
                              String project,
                              String location,
                              String publisher,
                              String modelName,
                              Long seed,
                              String language,
                              Integer guidanceScale,
                              String negativePrompt,
                              ImageStyle sampleImageStyle,
                              Integer sampleImageSize,
                              Integer maxRetries,
                              Boolean withPersisting,
                              Path persistTo) {
        this.endpoint = ValidationUtils.ensureNotBlank(endpoint, "endpoint");
        this.endpointName = EndpointName.ofProjectLocationPublisherModelName(
                ValidationUtils.ensureNotBlank(project, "project"),
                ValidationUtils.ensureNotBlank(location, "location"),
                ValidationUtils.ensureNotBlank(publisher, "publisher"),
                ValidationUtils.ensureNotBlank(modelName, "modelName"));
        this.seed = seed == null
                ? null
                : Long.valueOf(ValidationUtils.ensureBetween(seed, 0L, MAX_SEED, "seed"));
        this.language = language;
        this.guidanceScale = guidanceScale;
        this.negativePrompt = negativePrompt;
        this.sampleImageStyle = sampleImageStyle;
        this.sampleImageSize = sampleImageSize;
        this.maxRetries = maxRetries == null ? DEFAULT_MAX_RETRIES : maxRetries;
        this.withPersisting = withPersisting;
        this.tempDirectory = Boolean.TRUE.equals(withPersisting) ? createPersistenceDirectory(persistTo) : null;
    }

    /**
     * Resolves (and if necessary creates) the directory persisted images are written to.
     *
     * @param persistTo user-supplied directory, or {@code null} to create a fresh temporary directory
     * @return a directory that exists
     * @throws RuntimeException if the directory cannot be created, including when {@code persistTo}
     *                          already exists but is not a directory
     */
    private static Path createPersistenceDirectory(Path persistTo) {
        try {
            if (persistTo == null) {
                return Files.createTempDirectory("imagen-directory-");
            }
            // An existing regular file must be rejected here; otherwise the failure would only
            // surface on the first image write.
            if (!Files.isDirectory(persistTo)) {
                Files.createDirectories(persistTo);
            }
            return persistTo;
        } catch (IOException e) {
            throw new RuntimeException("Impossible to create persistence temporary directory", e);
        }
    }

    /**
     * Generates a single image for {@code prompt}.
     *
     * @return the first image of the prediction, with the response's token usage and finish reason
     */
    public Response<Image> generate(String prompt) {
        return firstImage(generate(prompt, 1));
    }

    /**
     * Generates {@code n} images for {@code prompt}.
     *
     * @param n value sent as the {@code sampleCount} request parameter
     * @return one image per prediction returned by the service
     */
    public Response<List<Image>> generate(String prompt, int n) {
        return generate(prompt, null, null, n);
    }

    /** Edits {@code image} according to {@code prompt}, without a mask. */
    public Response<Image> edit(Image image, String prompt) {
        return edit(image, null, prompt);
    }

    /**
     * Edits {@code image} according to {@code prompt}, optionally restricted by {@code mask}.
     *
     * @return the first image of the prediction
     */
    public Response<Image> edit(Image image, Image mask, String prompt) {
        return firstImage(generate(prompt, image, mask, 1));
    }

    /** Narrows a multi-image response to its first image, keeping token usage and finish reason. */
    private static Response<Image> firstImage(Response<List<Image>> response) {
        return Response.from(response.content().get(0), response.tokenUsage(), response.finishReason());
    }

    /**
     * Sends one prediction request and converts every returned prediction into an {@link Image}.
     *
     * <p>The prediction client is closed on every path, including when request preparation,
     * the retried {@code predict} call, or image persistence throws.
     *
     * @throws RuntimeException wrapping any {@link IOException} from client creation or request encoding
     */
    private Response<List<Image>> generate(String prompt, Image image, Image mask, int n) {
        try (PredictionServiceClient client = PredictionServiceClient.create(
                PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build())) {
            List<Value> instances = prepareInstance(prompt, image, mask);
            Value parameters = prepareParameters(n);
            PredictResponse response = RetryUtils.withRetry(
                    () -> client.predict(endpointName, instances, parameters), maxRetries);
            List<Image> images = response.getPredictionsList().stream()
                    .map(this::toImage)
                    .collect(Collectors.toList());
            return Response.from(images);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Converts one prediction into an {@link Image}, persisting it to disk when persistence is enabled. */
    private Image toImage(Value prediction) {
        String base64Data = prediction.getStructValue().getFieldsMap().get("bytesBase64Encoded").getStringValue();
        return Image.builder()
                .base64Data(base64Data)
                .url(persistAndGetURI(base64Data))
                .build();
    }

    /** Builds the request parameters; only options that were configured are included. */
    private Value prepareParameters(int sampleCount) throws InvalidProtocolBufferException {
        Map<String, Object> parameters = new HashMap<>();
        parameters.put("sampleCount", sampleCount);
        if (seed != null) {
            parameters.put("seed", seed);
        }
        if (sampleImageStyle != null) {
            parameters.put("sampleImageStyle", sampleImageStyle.name());
        }
        if (sampleImageSize != null) {
            // A sample image size is only sent together with upscale mode.
            parameters.put("mode", "upscale");
            parameters.put("sampleImageSize", sampleImageSize.toString());
        }
        if (guidanceScale != null) {
            parameters.put("guidanceScale", guidanceScale);
        }
        if (negativePrompt != null) {
            parameters.put("negativePrompt", negativePrompt);
        }
        if (language != null) {
            parameters.put("language", language);
        }
        return toValue(parameters);
    }

    /**
     * Builds the single prediction instance.
     *
     * <p>The instance always carries the prompt. It also carries the base image and the mask,
     * each only when present and carrying base64 data.
     */
    private List<Value> prepareInstance(String prompt, Image image, Image mask) throws InvalidProtocolBufferException {
        Map<String, Object> instance = new HashMap<>();
        instance.put("prompt", prompt);
        if (image != null && image.base64Data() != null) {
            instance.put("image", base64Payload(image));
        }
        if (mask != null && mask.base64Data() != null) {
            Map<String, Object> maskPayload = new HashMap<>();
            maskPayload.put("image", base64Payload(mask));
            instance.put("mask", maskPayload);
        }
        return Collections.singletonList(toValue(instance));
    }

    /** Wraps an image's base64 data as {@code {"bytesBase64Encoded": ...}}. */
    private static Map<String, Object> base64Payload(Image image) {
        Map<String, Object> payload = new HashMap<>();
        payload.put("bytesBase64Encoded", image.base64Data());
        return payload;
    }

    /** Serializes {@code json} to JSON text and parses it into a protobuf {@link Value}. */
    private static Value toValue(Map<String, Object> json) throws InvalidProtocolBufferException {
        Value.Builder builder = Value.newBuilder();
        JsonFormat.parser().merge(Json.toJson(json), builder);
        return builder.build();
    }

    /**
     * Writes the decoded image bytes to a new {@code .png} file in the persistence directory.
     *
     * @return the file's URI, or {@code null} when persistence is disabled
     * @throws RuntimeException wrapping any {@link IOException} raised while writing
     */
    private URI persistAndGetURI(String base64Data) {
        if (!Boolean.TRUE.equals(withPersisting)) {
            return null;
        }
        try {
            Path file = Files.createTempFile(tempDirectory, "imagen-image-", ".png");
            Files.write(file, Base64.getDecoder().decode(base64Data));
            return file.toUri();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns a new builder.
     *
     * <p>If a {@link VertexAiImageModelBuilderFactory} service is registered, the first one found
     * supplies the builder; otherwise a plain {@link Builder} is returned.
     */
    public static Builder builder() {
        for (VertexAiImageModelBuilderFactory factory : ServiceHelper.loadFactories(VertexAiImageModelBuilderFactory.class)) {
            return factory.get();
        }
        return new Builder();
    }
}
