/*
 * Copyright (c) 2019,2020 Dawid Walczak.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package pl.metaprogramming.codemodel.builder.java

import pl.metaprogramming.codemodel.builder.java.mapper.FixedMethodCmBuilder
import pl.metaprogramming.codemodel.formatter.CodeBuffer
import pl.metaprogramming.codemodel.formatter.JavaCodeFormatter
import pl.metaprogramming.codemodel.model.java.*
import pl.metaprogramming.codemodel.model.java.index.ClassIndex
import pl.metaprogramming.codemodel.model.java.index.DataTypeMapper
import pl.metaprogramming.codemodel.model.java.index.MapperEntry
import pl.metaprogramming.metamodel.data.ArrayType
import pl.metaprogramming.metamodel.data.DataType
import pl.metaprogramming.metamodel.data.MapType

import java.util.function.Function
import java.util.function.Predicate

import static pl.metaprogramming.codemodel.model.java.JavaDefs.*

/**
 * A build strategy class should implement at least one of the following methods:
 *
 * <ul>
 *     <li>makeDeclaration</li>
 *     <li>makeImplementation</li>
 *     <li>makeDecoration</li>
 * </ul>
 * @param <T> class representing the model for which the code is generated
 */
class ClassCmBuildStrategy<T> {

    /** Builder this strategy is bound to; set by {@link #init(ClassCmBuilder)}. */
    protected ClassCmBuilder<T> builder
    /** Optional mapping applied to {@code builder.model} before it is exposed via {@link #getModel()}. */
    protected Function<?, T> modelMapper
    /** Buffer collecting generated code lines (see {@link #addVarDeclaration(String, ClassCd, String)}). */
    CodeBuffer codeBuf = new CodeBuffer()


    void makeDeclaration() {
        // not every build strategy needs to implement this method
    }

    void makeImplementation() {
        // not every build strategy needs to implement this method
    }

    void makeDecoration() {
        // not every build strategy needs to implement this method
    }

    /**
     * Binds this strategy to the given builder.
     * Note: returns this same instance (after {@link #init}), not a new copy.
     */
    ClassCmBuildStrategy<T> getInstance(ClassCmBuilder<T> builder) {
        init(builder)
    }

    /** Stores the builder and returns {@code this} for chaining. */
    protected ClassCmBuildStrategy<T> init(ClassCmBuilder<T> builder) {
        this.builder = builder
        this
    }

    /** Returns the builder's model, passed through {@link #modelMapper} when one is set. */
    T getModel() {
        modelMapper ? modelMapper.apply(builder.model) : builder.model
    }

    /** Returns the class model currently being built. */
    ClassCm getClassModel() {
        builder.classCd
    }

    String getModelName() {
        builder.modelName
    }

    Object getClassType() {
        builder.classType
    }

    /**
     * Looks up the builder parameters registered under the given class.
     * The type variable is named {@code P} so it does not shadow the class-level {@code T}.
     */
    public <P> P getParams(Class<P> clazz) {
        builder.params.get(clazz)
    }

    /** Imports the given class unless it has no package (e.g. primitives or same-package classes). */
    void addImport(ClassCd classCd) {
        if (classCd.packageName) {
            classModel.imports.add(classCd.getCanonicalName())
        }
    }

    /** Adds a static wildcard import ({@code static a.b.C.*}) of the given class. */
    void addImportStatic(ClassCd classCd) {
        classModel.imports.add("static ${classCd.getCanonicalName()}.*".toString())
    }

    /** Adds the given raw import strings. */
    void addImports(List<String> imports) {
        classModel.imports.addAll(imports)
    }

    /**
     * Adds imports for a mix of raw strings and {@link ClassCd}s.
     * For a {@link ClassCd} with a package, its generic parameters are imported recursively.
     * Other values are ignored.
     */
    void addImports(Object... imports) {
        imports.each {
            if (it instanceof String) {
                classModel.imports.add(it as String)
            } else if (it instanceof ClassCd && it.packageName) {
                def classCd = it as ClassCd
                classModel.imports.add(classCd.canonicalName)
                if (classCd.genericParams) {
                    addImports(classCd.genericParams.toArray())
                }
            }
        }
    }

    void setInterface() {
        classModel.isInterface = true
    }

    void setModifiers(String modifiers) {
        classModel.modifiers = modifiers
    }

    /** Adds interfaces; each argument is resolved with {@link #getClass(Object, Object, boolean, boolean)}. */
    void addInterfaces(Object... interfaces) {
        interfaces.each {
            classModel.interfaces.add(getClass(it))
        }
    }

    List<ClassCd> getInterfaces() {
        classModel.interfaces
    }

    /** Marks the class as an enum with the given items. */
    void setEnums(List<EnumItemCm> enums) {
        classModel.isEnum = true
        classModel.enumItems = enums
    }

    void addGenericParams(ClassCd... genericParams) {
        classModel.genericParams.addAll(Arrays.asList(genericParams))
    }

    ClassCd getSuperClass() {
        return classModel.extend
    }

    /**
     * Sets the super class, resolving {@code classType} against {@code model}.
     * Fails (via {@link #panic}) if a super class has already been set.
     */
    void setSuperClass(Object classType, def model = null) {
        def superClass = getClass(classType, model)
        check(classModel.extend == null, "Can't set super class '$superClass', super class already set to '$classModel.extend'")
        classModel.extend = superClass
    }


    void setComment(String comment) {
        classModel.description = comment
    }

    void addAnnotation(AnnotationCm annotation) {
        classModel.annotations.add(annotation)
    }

    void addAnnotations(List<AnnotationCm> annotations) {
        classModel.annotations.addAll(annotations)
    }

    /** Adds methods through the builder (not directly to the class model). */
    void addMethods(List<MethodCm> methods) {
        methods.each { builder.addMethod(it) }
    }

    void addMethods(MethodCm... methods) {
        addMethods(methods.toList())
    }

    List<MethodCm> getMethods() {
        classModel.methods
    }

    /** Registers ready-made methods as mappers, each wrapped in a {@link FixedMethodCmBuilder}. */
    void addMappers(List<MethodCm> methods) {
        methods.each {
            addMapper(new FixedMethodCmBuilder(methodCm: it))
        }
    }

    void addMappers(MethodCm... methods) {
        addMappers(methods.toList())
    }

    void addMappers(MethodCmBuilder... mapperBuilders) {
        mapperBuilders.each { addMapper(it) }
    }

    /** Binds the mapper builder to this strategy and registers it in the class index. */
    void addMapper(MethodCmBuilder mapperBuilder) {
        classIndex.putMapper(bind(mapperBuilder))
    }

    /** Attaches this strategy, its builder and the class model (as owner) to the given method builder. */
    MethodCmBuilder bind(MethodCmBuilder mapperBuilder) {
        mapperBuilder.strategy = this
        mapperBuilder.builder = builder
        mapperBuilder.methodCm.ownerClass = classModel
        mapperBuilder
    }

    /**
     * Finds a registered mapper producing {@code resultType} from {@code params}.
     *
     * @return the mapper method, or {@code null} when not found and {@code failIfNotFound} is false
     * @throws IllegalStateException when not found and {@code failIfNotFound} is true
     */
    MethodCm findMapper(ClassCd resultType, List<ClassCd> params, boolean failIfNotFound = true) {
        def key = new MapperEntry.Key(to: resultType, from: params)
        def result = classIndex.getMapper(key)
        if (result) {
            result.methodCm
        } else {
            check(!failIfNotFound, "Can't find mapper $key")
            null
        }
    }

    void addFields(List<FieldCm> fields) {
        fields.each { classModel.addField(it) }
    }

    void addFields(FieldCm... fields) {
        addFields(fields.toList())
    }

    /** Creates a field with the given name and type, adds it to the class and returns it. */
    FieldCm addField(String name, ClassCd type) {
        def field = new FieldCm(type, name)
        addFields(field)
        field
    }

    List<FieldCm> findFields(Predicate<FieldCm> predicate) {
        classModel.fields.findAll { predicate.test(it) }
    }


    /**
     * Returns a field through which {@code toInject} can be used from the class being built.
     * <ul>
     *     <li>for the class itself a synthetic {@code this} field is returned (not added);</li>
     *     <li>an existing field of that type is reused;</li>
     *     <li>otherwise a new field is added, named after the class with the first letter lower-cased
     *     and, for interfaces named like {@code IFoo}, the leading {@code I} dropped.</li>
     * </ul>
     */
    FieldCm injectDependency(ClassCd toInject) {
        check(toInject != null, "toInject can't be null")
        if (classModel == toInject) {
            return new FieldCm(toInject, 'this')
        }
        FieldCm existField = classModel.fields.find { it.type == toInject }
        if (existField) {
            return existField
        }
        def fieldName = toInject.isInterface && toInject.className ==~ /I[A-Z].*/
                ? toInject.className.substring(1).uncapitalize()
                : toInject.className.uncapitalize()
        def newField = new FieldCm(toInject, fieldName)
        classModel.fields.add(newField)
        newField
    }

    /** Like {@link #getClass(Object, Object, boolean, boolean)} but without marking the class as used. */
    ClassCd findClass(Object classType, Object model = builder.model, boolean optional = false) {
        getClass(classType, model, optional, false)
    }

    /**
     * Resolves a class declaration.
     * <ul>
     *     <li>a {@code String} is wrapped in a new {@link ClassCd};</li>
     *     <li>a {@link ClassCd} is returned as is;</li>
     *     <li>a {@link DataType} class type is mapped by the configured data type mapper;</li>
     *     <li>a {@link DataType} model is resolved via {@link #findDataSchemaClass};</li>
     *     <li>otherwise the class index is queried.</li>
     * </ul>
     *
     * @param optional when false, a failed class-index lookup raises an error instead of returning null
     * @param markAsUsed forwarded to the class index lookup
     */
    ClassCd getClass(Object classType, Object model = builder.model, boolean optional = false, boolean markAsUsed = true) {
        if (classType instanceof String) {
            return new ClassCd(classType)
        }
        if (classType instanceof ClassCd) {
            return classType
        }
        if (classType instanceof DataType) {
            return config.dataTypeMapper.map(classType)
        }
        if (model instanceof DataType) {
            return findDataSchemaClass(model, classType, markAsUsed)
        }
        def classCd = classIndex.getClass(classType, model, markAsUsed)
        check(classCd != null || optional, "Can't find Class ($classType) for $model")
        classCd
    }

    ClassCm getClassCm(Object classType, Object model = builder.model, boolean optional = false, boolean markAsUsed = true) {
        getClass(classType, model, optional, markAsUsed) as ClassCm
    }

    Optional<ClassCd> getClassOptional(Object classType, def model = builder.model) {
        Optional.ofNullable(getClass(classType, model, true))
    }

    /** Resolves {@code classType} and parameterizes it with the resolved {@code genericClassTypes}. */
    ClassCd getGenericClass(Object classType, List genericClassTypes) {
        new ClassCd(getClass(classType), genericClassTypes.collect { getClass(it) })
    }

    /**
     * Resolves the class representing {@code dataType} (of kind {@code classType}).
     * Any lookup exception is wrapped with context via {@link #panic}. The not-found check runs
     * outside the try block so its own failure is not caught and reported a second time.
     *
     * @param markAsUsed forwarded to the class index lookup (same meaning as in {@link #getClass})
     */
    ClassCd findDataSchemaClass(DataType dataType, def classType, boolean markAsUsed) {
        assert dataType != null
        ClassCd classCd = null
        try {
            classCd = getClassForDataType(dataType, classType, config.dataTypeMapper, markAsUsed)
        } catch (Exception e) {
            panic("Can't find class for $dataType (of type $classType)", e)
        }
        check classCd != null, "Can't find class for $dataType (of type $classType)"
        classCd
    }


    /**
     * Maps a data type to a class.
     * The data type mapper is tried first. Arrays become {@code List<item>} and maps become
     * {@code Map<String, value>}, with item/value classes resolved recursively. Anything else is
     * looked up in the class index.
     *
     * @param markAsUsed forwarded to the class index lookup
     */
    ClassCd getClassForDataType(DataType dataType, def classType, DataTypeMapper dataTypeMapper, boolean markAsUsed) {
        if (dataType && dataTypeMapper) {
            def result = dataTypeMapper.map(dataType, classType)
            if (result) {
                return result
            }
        }
        if (dataType instanceof ArrayType) {
            def genericParam = getClassForDataType(dataType.itemsSchema.dataType, classType, dataTypeMapper, markAsUsed)
            check genericParam != null, "Undefined item type for collection field: $dataType"
            new ClassCd(T_LIST, [genericParam])
        } else if (dataType instanceof MapType) {
            def genericParam = getClassForDataType(dataType.valuesSchema.dataType, classType, dataTypeMapper, markAsUsed)
            check genericParam != null, "Can't find class for $dataType.valuesSchema.dataType (of type $classType)"
            new ClassCd(T_MAP, [T_STRING, genericParam])
        } else {
            classIndex.getClass(classType, dataType, markAsUsed)
        }
    }

    // ------------------------------------------------------------------------
    // helper methods for methods bodies creation
    // ------------------------------------------------------------------------

    /** Returns (does not buffer) a Java local variable declaration for the field, importing its type. */
    String addVarDeclaration(FieldCm fieldCm) {
        addImport(fieldCm.type)
        "$fieldCm.type.className $fieldCm.name = $fieldCm.value;"
    }

    /**
     * Appends a local variable declaration to {@link #codeBuf} and returns the variable name.
     * The special expression {@code 'new'} becomes a no-arg constructor call of {@code fieldType}.
     */
    String addVarDeclaration(String fieldName, ClassCd fieldType, String fieldExp) {
        addImport(fieldType)
        String fixedFieldExp = fieldExp == 'new'
                ? "new $fieldType.className()" : fieldExp
        codeBuf.addLines("$fieldType.className $fieldName = $fixedFieldExp;")
        fieldName
    }

    /** Injects the component of the given class type and returns {@code <fieldName>.<callExp>}. */
    String callComponent(def classType, String callExp) {
        injectDependency(getClass(classType)).name + '.' + callExp
    }

    FieldCm transform(FieldCm from, def toClassType, String variableName = null) {
        transform([from], toClassType, variableName)
    }

    FieldCm transform(def fromClassType, def toClassType, def metamodel, String fromValueExp) {
        def fromClass = getClass(fromClassType, metamodel)
        def toClass = getClass(toClassType, metamodel)
        transform(new FieldCm(fromClass).assign(fromValueExp), toClass)
    }

    /** Returns a field of the target type whose value is the expression transforming {@code from}. */
    FieldCm transform(List<FieldCm> from, def toClassType, String variableName = null) {
        def type = getClass(toClassType)
        new FieldCm(type, variableName).assign(makeTransformation(type, from))
    }

    /** Returns a transformation expression from {@code from} to {@code to}, as a method reference where possible. */
    String transformWithMethodRef(ClassCd from, ClassCd to) {
        makeTransformation(to, [new FieldCm(type: from)], 1, true)
    }

    /**
     * Builds a Java expression converting the {@code from} values into {@code toType}.
     * The cases, in order, are:
     * <ul>
     *     <li>list-to-list or map-to-map element-wise transformation;</li>
     *     <li>enum-to-String conversion;</li>
     *     <li>identity, when the single source already has the target type;</li>
     *     <li>a registered mapper call, or {@code new <Type>()} when {@code from} is empty and no mapper exists.</li>
     * </ul>
     *
     * @param level nesting depth, used to name lambda variables of nested collections
     * @param allowMethodReference whether a single-argument mapper call may be rendered as {@code X::m}
     */
    String makeTransformation(
            ClassCd toType,
            List<FieldCm> from,
            int level = 1,
            boolean allowMethodReference = false) {
        if (isCollectionTransformation(toType, from)) {
            makeCollectionTransformation(toType, from, level)
        } else if (from.size() == 1 && from[0].type.isEnum && toType == T_STRING) {
            makeEnumToStringTransformation(from[0])
        } else if ([toType] == from.type) {
            makeTransformationParams(from)
        } else {
            def mapper = findMapper(toType, from.type, from.size() != 0)
            if (mapper == null) {
                addImport(toType)
                return "new ${toType.className}()"
            }
            if (mapper.static) {
                addImport(mapper.ownerClass)
                makeCall(mapper.ownerClass.className, mapper.name, from, allowMethodReference)
            } else {
                def mapperField = injectDependency(mapper.ownerClass)
                makeCall(mapperField.name, mapper.name, from, allowMethodReference)
            }
        }
    }

    /** Renders {@code instance::method} or a call {@code [instance.]method(args)}; {@code this.} is omitted. */
    private String makeCall(String instance, String method, List<FieldCm> from, boolean allowMethodReference) {
        if (allowMethodReference && from.size() == 1) {
            "${instance}::${method}"
        } else {
            def call = "${method}(${makeTransformationParams(from)})"
            instance == 'this' ? call : "${instance}.${call}"
        }
    }

    /**
     * Builds an element-wise list/map transformation using the registered generic list/map mapper,
     * e.g. {@code baseDataMapper.transformList(raw.getAuthors(), v -> baseDataMapper.toLong(v))}.
     *
     * @return the expression, or null when the item transformation could not be built
     */
    String makeCollectionTransformation(ClassCd toType, List<FieldCm> from, int level) {
        def isList = toType.isClass(T_LIST)
        // lists are transformed per item (generic param 0), maps per value (generic param 1)
        def itemParamIdx = isList ? 0 : 1
        def toItemType = toType.genericParams.get(itemParamIdx)
        def fromItemType = from[0].type.genericParams.get(itemParamIdx)
        // nested collections need distinct lambda variables: v, v2, v3, ...
        def varName = 'v' + (level > 1 ? level : '')
        def itemFrom = [new FieldCm(fromItemType, varName)] + from.subList(1, from.size())
        def itemTransformation = makeTransformation(toItemType, itemFrom, level + 1, true)
        if (!itemTransformation) {
            return null
        }

        def mapper = isList ? findMapper(LIST_R, [LIST_T, FUN_T_R]) : findMapper(MAP_KR, [MAP_KT, FUN_T_R])
        // Only a bare method reference (e.g. 'Mapper::toLong') is a valid Function argument on its own;
        // any other expression (including a dot-free call such as 'toFoo(v, x)') must be wrapped in a lambda.
        def itemFunction = isMethodReference(itemTransformation)
                ? itemTransformation
                : "${varName} -> ${itemTransformation}"
        String methodCall = "${mapper.name}(${makeTransformationParams([from[0]])}, $itemFunction)"
        if (mapper.static) {
            addImport(mapper.ownerClass)
            "${mapper.ownerClass.className}.$methodCall"
        } else {
            def mapperField = injectDependency(mapper.ownerClass)
            "${mapperField.name}.$methodCall"
        }
    }

    /** True when the expression is a bare method reference such as {@code Foo::bar} or {@code this::bar}. */
    private static boolean isMethodReference(String expression) {
        expression ==~ '[\\w$.]+::[\\w$]+'
    }

    /**
     * Converts an enum value to its String form via {@code getValue()}, null-safely unless the field is nonnull.
     * The source expression is rendered like any other transformation parameter: a getter chain for named
     * fields, or the assigned value for unnamed ones.
     */
    String makeEnumToStringTransformation(FieldCm from) {
        String valueExp = makeTransformationParams([from])
        if (from.nonnull) {
            return "${valueExp}.getValue()"
        }
        addImports(OPTIONAL, from.type)
        "Optional.ofNullable(${valueExp}).map(${from.type.className}::getValue).orElse(null)"
    }

    /** True when both the target and the first source are lists, or both are maps; false for no sources. */
    boolean isCollectionTransformation(ClassCd toType, List<FieldCm> from) {
        !from.isEmpty() && (
                (toType.isClass(T_LIST) && from[0].type.isClass(T_LIST)) ||
                (toType.isClass(T_MAP) && from[0].type.isClass(T_MAP)))
    }

    /** Renders the sources as comma-separated arguments: getter chains for named fields, values otherwise. */
    String makeTransformationParams(List<FieldCm> from) {
        from.collect { it.name ? getter(it.name) : it.value }.join(', ')
    }


    /** Converts a dotted path such as {@code a.b.c} into {@code a.getB().getC()}. */
    String getter(String path) {
        getter(Arrays.asList(path.split('\\.')))
    }

    String getter(List<String> path) {
        path[0] + path.drop(1).collect {
            '.get' + JavaCodeFormatter.toJavaName(it, true) + '()'
        }.join('')
    }

    private ClassIndex getClassIndex() {
        builder.classIndex
    }

    private ClassBuilderConfigurator getConfig() {
        builder.config
    }

    /** Raises an error via {@link #panic} when {@code isOk} is false. */
    protected void check(boolean isOk, String errorMessage) {
        if (!isOk) {
            panic(errorMessage)
        }
    }

    /** Prints the class index and throws an {@link IllegalStateException} describing the class being built. */
    protected void panic(String message, Exception cause = null) {
        classIndex.printIndex()
        throw new IllegalStateException("Can't build $classModel ($config.classType for $model)\n$message", cause)
    }

}