/*
 * Copyright (c) 2019 Dawid Walczak.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package pl.metaprogramming.codemodel.builder.java.rest

import groovy.transform.TypeChecked
import groovy.transform.TypeCheckingMode
import groovy.util.logging.Slf4j
import pl.metaprogramming.codemodel.builder.java.ClassCmBuildStrategy
import pl.metaprogramming.codemodel.builder.java.base.BaseEnumBuildStrategy
import pl.metaprogramming.codemodel.builder.java.config.ValidationParams
import pl.metaprogramming.codemodel.builder.java.spring.mapper.DefaultValueMapper
import pl.metaprogramming.codemodel.formatter.JavaCodeFormatter
import pl.metaprogramming.codemodel.model.java.*
import pl.metaprogramming.metamodel.data.DataSchema
import pl.metaprogramming.metamodel.data.ObjectType
import pl.metaprogramming.metamodel.data.constraints.DataConstraint
import pl.metaprogramming.metamodel.data.constraints.DictionaryConstraint
import pl.metaprogramming.metamodel.oas.Parameter

import static pl.metaprogramming.codemodel.builder.java.ClassType.*
import static pl.metaprogramming.codemodel.builder.java.MetaModelAttribute.DESCRIPTION_FIELD_NAME
import static pl.metaprogramming.codemodel.model.java.JavaDefs.T_STRING
import static pl.metaprogramming.metamodel.data.DataTypeCode.*

@Slf4j
abstract class ValidatorBuildStrategy extends ClassCmBuildStrategy<ObjectType> {

    // Class-type code passed to getClass(...) in getDtoClass() to resolve the DTO class.
    def dtoTypeOfCode
    // Class-type code passed to getClass(...) in getObjectChecker() to resolve the validator of another object type.
    def validatorTypeOfCode
    // Switch read by getCheckers, addEnumChecker, addMinMaxChecker and addXValidations;
    // when true the type-code, enum and conversion-based variants of the checkers are generated.
    boolean addDataTypeValidations
    // Class-type code passed to getClass(...) in addEnumChecker() to resolve the enum class.
    def enumTypeOfCode

    /**
     * Resolves the class registered under {@code dtoTypeOfCode}.
     *
     * @return the resolved class coerced to {@code ClassCm}
     */
    protected ClassCm getDtoClass() {
        def resolved = getClass(dtoTypeOfCode)
        return resolved as ClassCm
    }

    // Schema of the field currently being processed: assigned by addValidation()
    // and used as the default argument of getDescription().
    DataSchema currentFieldSchema

    /**
     * Builds the {@code check} method with {@code prepareCheckMethod()} and
     * registers it through {@code addMethods}.
     */
    void addCheckMethod() {
        MethodCm checkMethod = prepareCheckMethod()
        addMethods(checkMethod)
    }

    /**
     * Creates the model of the {@code check(ctx)} method.
     * <p>
     * The {@code ctx} parameter type is the {@code VALIDATION_CONTEXT} class
     * parameterized with the generic parameters of the super class or, when
     * there is no super class, of the first implemented interface.
     *
     * @return method model named {@code check} whose body comes from {@code prepareImplBody()}
     */
    protected MethodCm prepareCheckMethod() {
        def contextGenerics
        if (superClass) {
            contextGenerics = superClass.genericParams
        } else {
            contextGenerics = interfaces[0].genericParams
        }
        // The parameter is created before the body is generated (same order as the named-argument form).
        def contextParam = new FieldCm(
                name: 'ctx',
                type: getGenericClass(VALIDATION_CONTEXT, contextGenerics))
        String body = prepareImplBody()
        return new MethodCm(name: 'check', params: [contextParam], implBody: body)
    }

    /**
     * Generates the body of the {@code check} method.
     * <p>
     * Steps, in order: register beans listed under the {@code x-validation-beans}
     * additive, delegate to the validator of the first inherited type (if any),
     * add a validation for every field of the model, then add the object-level
     * {@code x-validations}.
     *
     * @return the content taken from {@code codeBuf}
     */
    protected String prepareImplBody() {
        def beanSpecs = model.additives['x-validation-beans']
        beanSpecs?.each { beanSpec ->
            setValidationBean(beanSpec as Map<String, String>)
        }
        if (model.inherits) {
            def parentChecker = getObjectChecker(model.inherits[0])
            codeBuf.addLines("${parentChecker}.checkWithParent(ctx);")
        }
        for (field in model.fields) {
            addValidation(field)
        }
        addXValidations()
        return codeBuf.take()
    }

    /**
     * Emits a {@code ctx.setBean(...)} line for one {@code x-validation-beans} entry.
     *
     * @param spec entry with a {@code class} key (bean class name) and a
     *             {@code factory} key in the form {@code <factoryClass>:<method>};
     *             both keys are asserted to be present and the factory value is
     *             asserted to have exactly two ':'-separated parts
     */
    protected void setValidationBean(Map<String, String> spec) {
        assert spec['class']
        assert spec['factory']
        ClassCd beanType = new ClassCd(spec['class'])
        def factoryDesc = spec['factory'].split(':')
        assert factoryDesc.length == 2
        addImports(beanType)
        // The factory class is injected as a dependency; its method is passed as a method reference.
        def factory = injectDependency(new ClassCd(factoryDesc[0]))
        def factoryMethod = factoryDesc[1]
        codeBuf.newLine("ctx.setBean(${beanType.className}.class, ${factory.name}::${factoryMethod});")
    }

    /**
     * Emits a {@code ctx.check(...)} line for a single field when at least one
     * checker applies to it.
     * <p>
     * Also records the schema in {@code currentFieldSchema}, so that
     * {@code getDescription()} called without arguments refers to this field.
     *
     * @param schema schema of the field; the DTO class must contain a field whose model equals it
     */
    protected void addValidation(DataSchema schema) {
        currentFieldSchema = schema
        def dtoField = dtoClass.fields.find { candidate -> candidate.model == schema }
        def fieldCheckers = getCheckers(schema, dtoField.type)
        if (!fieldCheckers) {
            return
        }
        addImportStatic(dtoClass)
        addImportStatic(getClass(VALIDATION_COMMON_CHECKERS))
        codeBuf.newLine("ctx.check(${getDescription()}, ${fieldCheckers.join(', ')});")
    }

    /**
     * Collects the checker expressions that apply to a schema.
     * <p>
     * High-priority directives come first, then {@code required()} for required
     * schemas. Arrays and maps are handled exclusively by their dedicated
     * methods; for any other schema the type, object, enum, format, pattern,
     * length, low-priority and min/max checkers are added in that order.
     *
     * @param schema schema to build checkers for
     * @param type   class of the value described by the schema
     * @param level  nesting level, 1 for the field itself; forwarded to the array and map handlers
     * @return list of checker expressions, possibly empty
     */
    protected List<String> getCheckers(DataSchema schema, ClassCd type, int level = 1) {
        def directives = ValidationDirectives.forSchema(schema)
        List<String> result = []
        addConstraintsCheckers(result, directives.highPriority)
        if (schema.isRequired) {
            result << 'required()'
        }
        if (schema.isArray()) {
            addArrayChecker(result, schema, type, level)
        } else if (schema.isMap()) {
            addMapChecker(result, schema, type, level)
        } else {
            if (addDataTypeValidations && schema.isType(BOOLEAN, INT32, INT64, FLOAT, DOUBLE)) {
                result << schema.dataType.typeCode.name()
            }
            addObjectChecker(result, schema)
            addEnumChecker(result, schema)
            addFormatChecker(result, schema)
            addPatternChecker(result, schema)
            addLengthChecker(result, schema)
            addConstraintsCheckers(result, directives.lowPriority)
            // validation that requires data conversion should be performed at the end
            addMinMaxChecker(result, schema)
        }
        return result
    }

    /**
     * Adds the checker configured for the schema's format, if one is registered
     * in the {@code formatValidators} of {@code ValidationParams}.
     *
     * @param checkers list to append to
     * @param schema   schema whose {@code format} is looked up
     */
    @TypeChecked(TypeCheckingMode.SKIP)
    protected void addFormatChecker(List<String> checkers, DataSchema schema) {
        def formatValidators = getParams(ValidationParams).formatValidators
        def formatValidator = formatValidators.getValidator(schema.format)
        if (!formatValidator) {
            return
        }
        checkers.add(toJavaExpr(formatValidator))
    }

    /**
     * Extracts the textual entries of the {@code x-validations} additive.
     *
     * @param additives additives map to read from
     * @return the entries that are {@code CharSequence}s, or an empty list when
     *         the key is absent or no entry is textual
     */
    static protected Collection<String> getValidations(Map additives) {
        def declared = additives['x-validations']
        def textual = declared?.findAll { entry -> entry instanceof CharSequence }
        return textual ? textual : Collections.emptyList()
    }

    /**
     * Translates constraints into checker expressions and appends them.
     * <p>
     * A {@code ValidationDirective} contributes the Java expression of its
     * pointer; a {@code DictionaryConstraint} is resolved with a
     * {@code DictionaryConstraintResolver}. Other constraint types are ignored.
     *
     * @param checkers   list to append to
     * @param directives constraints to translate, processed in the given order
     */
    protected void addConstraintsCheckers(List<String> checkers, List<DataConstraint> directives) {
        for (constraint in directives) {
            if (constraint instanceof ValidationDirective) {
                checkers.add(toJavaExpr(constraint.pointer))
            }
            if (constraint instanceof DictionaryConstraint) {
                def resolver = new DictionaryConstraintResolver()
                checkers.add(resolver.resolve(constraint))
            }
        }
    }

    /**
     * Emits {@code ctx.check(...)} lines for the object-level {@code x-validations}.
     * <p>
     * Each textual entry is either a validator pointer expression or a compare
     * expression; entries of any other form are skipped. For compare
     * expressions a deserialize mapper of the first field is appended as a
     * third argument when {@code addDataTypeValidations} is set.
     */
    protected void addXValidations() {
        // Phase 1: translate every entry; phase 2: emit. Nothing is written before all entries are translated.
        def expressions = []
        for (String desc in getValidations(model.additives)) {
            def javaExpr = null
            if (ValidatorPointer.isValidatorPointerExpression(desc)) {
                javaExpr = toJavaExpr(ValidatorPointer.fromExpression(desc))
            } else if (CompareExpression.isCheckExpression(desc)) {
                def exp = CompareExpression.parse(desc, model)
                def arguments = [getDescription(exp.field1), getDescription(exp.field2)]
                if (addDataTypeValidations) {
                    arguments << getDeserializeMapper(exp.field1)
                }
                javaExpr = "${exp.operator.name()}(${arguments.join(', ')})"
            }
            if (javaExpr != null) {
                expressions << javaExpr
            }
        }
        for (checkExpr in expressions) {
            codeBuf.addLines("ctx.check(${checkExpr});")
        }
    }

    /**
     * Reads the additive stored under {@code DESCRIPTION_FIELD_NAME}.
     * <p>
     * The value is used in this class as the first argument of generated
     * {@code ctx.check(...)} calls and as the prefix of checker field names.
     *
     * @param schema schema to read from; defaults to {@code currentFieldSchema}
     * @return the additive value
     */
    String getDescription(DataSchema schema = currentFieldSchema) {
        return schema.getAdditive(DESCRIPTION_FIELD_NAME)
    }

    /**
     * Produces a method-reference transformation from {@code T_STRING} to the
     * class of the schema's data type.
     *
     * @param schema schema whose data type is the transformation target
     * @return result of {@code transformWithMethodRef}
     */
    String getDeserializeMapper(DataSchema schema) {
        def targetClass = getClass(schema.dataType)
        return transformWithMethodRef(T_STRING, targetClass)
    }

    /**
     * Adds the validator of the nested object type for object schemas that are
     * not {@code Parameter}s.
     *
     * @param checkers list to append to
     * @param schema   schema to inspect
     */
    protected void addObjectChecker(List<String> checkers, DataSchema schema) {
        if (schema.isObject() && !(schema instanceof Parameter)) {
            checkers.add(getObjectChecker(schema.objectType))
        }
    }

    /**
     * Injects the validator class of the given object type as a dependency.
     *
     * @param objectType object type whose validator is needed
     * @return name of the injected dependency
     */
    protected String getObjectChecker(ObjectType objectType) {
        def validatorClass = getClass(validatorTypeOfCode, objectType)
        def dependency = injectDependency(validatorClass)
        return dependency.name
    }

    /**
     * Adds the checkers of an array schema: {@code required()} when at least one
     * item is demanded, the size checker, and a {@code list(...)} wrapper around
     * the checkers of the item schema.
     * <p>
     * {@code getCheckers} adds {@code required()} for required schemas before
     * calling this method, so it is added here only when not yet present;
     * otherwise a required array with {@code minItems > 0} would get the same
     * checker twice.
     *
     * @param checkers list to append to
     * @param schema   schema to inspect; ignored unless it is an array
     * @param type     class of the array value; its first generic parameter is used as the item class
     * @param level    nesting level of this array, incremented for the items
     */
    protected void addArrayChecker(List<String> checkers, DataSchema schema, ClassCd type, int level) {
        if (!schema.isArray()) return
        if (schema.arrayType.minItems > 0 && !checkers.contains('required()')) {
            checkers.add('required()')
        }
        addListSizeChecker(checkers, schema, type, level)
        def itemCheckers = getCheckers(schema.arrayType.itemsSchema, type.genericParams[0], level + 1)
        if (!itemCheckers.isEmpty()) {
            checkers.add("list(${itemCheckers.join(', ')})".toString())
        }
    }

    /**
     * Adds a {@code mapValues(...)} wrapper around the checkers of the map's
     * value schema, when there are any.
     *
     * @param checkers list to append to
     * @param schema   schema to inspect; ignored unless it is a map
     * @param type     class of the map value; its second generic parameter is used as the value class
     * @param level    nesting level of this map, incremented for the values
     */
    protected void addMapChecker(List<String> checkers, DataSchema schema, ClassCd type, int level) {
        if (schema.isMap()) {
            def valueCheckers = getCheckers(schema.mapType.valuesSchema, type.genericParams[1], level + 1)
            if (valueCheckers.size() > 0) {
                checkers.add('mapValues(' + valueCheckers.join(', ') + ')')
            }
        }
    }

    /**
     * Adds a checker restricting the value to the items of the enum class.
     * Does nothing unless the schema is an enum and {@code addDataTypeValidations} is set.
     * <p>
     * The checker is kept in a field named after the upper-cased enum class
     * name; the field is declared only if no field of that name exists yet.
     *
     * @param checkers list to append to
     * @param schema   schema to inspect
     */
    protected void addEnumChecker(List<String> checkers, DataSchema schema) {
        if (!(schema.isEnum() && addDataTypeValidations)) return
        def enumClass = getClass(enumTypeOfCode, schema.dataType)
        def fieldName = JavaCodeFormatter.toUpperCase(enumClass.className)
        def sameNamed = findFields { declared -> declared.name == fieldName }
        if (sameNamed.isEmpty()) {
            addImport(enumClass)
            addCheckerField(fieldName, "allow(${enumClass.className}.values())")
        }
        checkers.add(fieldName)
    }

    /**
     * Adds a regular-expression checker for schemas that declare a pattern.
     * <p>
     * The checker is stored in a field named {@code <description>_PATTERN}.
     * Backslashes and double quotes of the pattern are escaped so it can be
     * embedded in a Java string literal; the invalid-pattern code from the
     * schema constraints, when set, is passed as a second argument.
     *
     * @param checkers list to append to
     * @param schema   schema to inspect
     */
    protected void addPatternChecker(List<String> checkers, DataSchema schema) {
        if (!schema.pattern) {
            return
        }
        String checkerName = "${getDescription()}_PATTERN"
        addImports('java.util.regex.Pattern')
        String withEscapedBackslashes = schema.pattern.replace('\\', '\\\\')
        String javaLiteral = withEscapedBackslashes.replace('"', '\\"')
        String errorCode = schema.constraints.invalidPatternCode
        String errorCodeArg = errorCode ? (', "' + errorCode + '"') : ''
        addCheckerField(checkerName, "matches(Pattern.compile(\"${javaLiteral}\")${errorCodeArg})")
        checkers.add(checkerName)
    }

    /**
     * Adds a {@code range(...)} checker for schemas that declare a minimum or a maximum.
     * <p>
     * The checker is stored in a field named {@code <description>_MIN_MAX}.
     * With {@code addDataTypeValidations} the bounds are passed as escaped
     * values together with a constructor reference of the data type class;
     * otherwise the bounds are produced by {@code DefaultValueMapper} and the
     * checker field is typed with the data type class.
     *
     * @param checkers list to append to
     * @param schema   schema to inspect
     */
    protected void addMinMaxChecker(List<String> checkers, DataSchema schema) {
        if (!(schema.minimum || schema.maximum)) {
            return
        }
        String checkerName = "${getDescription()}_MIN_MAX"
        def type = getClass(schema.dataType)
        addImports(type)
        if (addDataTypeValidations) {
            def lower = ValueCm.escaped(schema.minimum)
            def upper = ValueCm.escaped(schema.maximum)
            addCheckerField(checkerName, "range(${type.className}::new, ${lower}, ${upper})")
        } else {
            def lower = DefaultValueMapper.make(type, schema.minimum, this)
            def upper = DefaultValueMapper.make(type, schema.maximum, this)
            addCheckerField(checkerName, "range(${lower}, ${upper})", type)
        }
        checkers.add(checkerName)
    }

    /**
     * Adds a {@code length(min, max)} checker for schemas that declare a
     * minimum or maximum length. The checker is stored in a field named
     * {@code <description>_LENGTH}.
     *
     * @param checkers list to append to
     * @param schema   schema to inspect
     */
    protected void addLengthChecker(List<String> checkers, DataSchema schema) {
        if (!(schema.minLength || schema.maxLength)) {
            return
        }
        String checkerName = "${getDescription()}_LENGTH"
        def lengthExpr = "length(${schema.minLength}, ${schema.maxLength})"
        addCheckerField(checkerName, lengthExpr)
        checkers.add(checkerName)
    }

    /**
     * Adds a {@code size(min, max)} checker for arrays that declare
     * {@code minItems} or {@code maxItems}.
     * <p>
     * The checker field is named {@code <description>_SIZE}; for nesting levels
     * above 1 the level is inserted ({@code <description>_<level>_SIZE}) so that
     * nested arrays of one field get distinct names.
     *
     * @param checkers list to append to
     * @param schema   array schema to inspect
     * @param type     class of the array value, used as the tested type of the checker field
     * @param level    nesting level of the array
     */
    protected void addListSizeChecker(List<String> checkers, DataSchema schema, ClassCd type, int level) {
        def arrayType = schema.arrayType
        if (!(arrayType.minItems || arrayType.maxItems)) {
            return
        }
        String levelSuffix = level > 1 ? '_' + level : ''
        String checkerName = "${getDescription()}${levelSuffix}_SIZE"
        addCheckerField(checkerName, "size(${schema.arrayType.minItems}, ${schema.arrayType.maxItems})", type)
        checkers.add(checkerName)
    }

    /**
     * Declares a {@code private static final} field holding a checker.
     *
     * @param checkerName  name of the field
     * @param checkerValue expression assigned to the field
     * @param testedType   generic parameter of the {@code VALIDATION_CHECKER} field type;
     *                     defaults to {@code T_STRING}
     */
    protected void addCheckerField(String checkerName, String checkerValue, ClassCd testedType = T_STRING) {
        def checkerType = getGenericClass(VALIDATION_CHECKER, [testedType])
        def initializer = ValueCm.value(checkerValue)
        def checkerField = new FieldCm(
                name: checkerName,
                modifiers: 'private static final',
                type: checkerType,
                value: initializer)
        addFields(checkerField)
    }

    /**
     * Converts a validator pointer into the Java expression used as a checker.
     * <ul>
     *   <li>pointer whose class name equals the name of {@code VALIDATION_COMMON_CHECKERS}:
     *       the bare static field (no import is added);</li>
     *   <li>pointer with a static field: {@code <Class>.<field>};</li>
     *   <li>otherwise an object expression, either an injected dependency
     *       (DI bean) or {@code ctx.getBean(<Class>.class)}, optionally followed by
     *       {@code ::<method>}.</li>
     * </ul>
     *
     * @param validator pointer to translate
     * @return Java expression for the checker
     */
    String toJavaExpr(ValidatorPointer validator) {
        if (validator.classCanonicalName == VALIDATION_COMMON_CHECKERS.name()) {
            return validator.staticField
        }

        ClassCd validatorClass = new ClassCd(validator.classCanonicalName)
        addImports(validatorClass)

        if (validator.staticField) {
            return "${validatorClass.className}.${validator.staticField}"
        }

        String target
        if (validator.isDiBean) {
            target = injectDependency(validatorClass).name
        } else {
            target = "ctx.getBean(${validatorClass.className}.class)"
        }
        if (validator.method) {
            return "${target}::${validator.method}"
        }
        return target
    }

    /**
     * Turns a {@code DictionaryConstraint} into a checker expression, using the
     * class-building facilities of the enclosing strategy.
     */
    class DictionaryConstraintResolver {

        /**
         * Builds a transformation to a {@code VALIDATION_CHECKER} of {@code T_STRING}
         * from the dictionary code and, when present, the invalid code of the constraint.
         *
         * @param constraint constraint to resolve
         * @return result of {@code makeTransformation}
         */
        String resolve(DictionaryConstraint constraint) {
            def checkerType = getClass(VALIDATION_CHECKER).withGeneric(T_STRING)
            def arguments = [makeEnumParam(VALIDATION_DICTIONARY_CODES_ENUM, constraint.dictionaryCode, false)]
            if (constraint.invalidCode) {
                arguments.add(makeEnumParam(VALIDATION_INVALID_CODES_ENUM, constraint.invalidCode, true))
            }
            return makeTransformation(checkerType, arguments)
        }

        /**
         * Registers an enum item derived from {@code value} in the enum class of
         * the given type and returns a field of that enum class assigned to the item.
         *
         * @param enumType  class-type code of the enum class
         * @param value     raw value the item name is derived from
         * @param withValue when true the item also carries the escaped raw value
         */
        private FieldCm makeEnumParam(def enumType, String value, boolean withValue) {
            def enumClass = getClassCm(enumType)
            addImport(enumClass)
            def itemName = BaseEnumBuildStrategy.toEnumItemName(value)
            def enumItem = new EnumItemCm(name: itemName)
            if (withValue) {
                enumItem.value = ValueCm.escaped(value)
            }
            enumClass.addEnumItem(enumItem)
            def enumField = new FieldCm(enumClass)
            return enumField.assign("${enumClass.className}.${enumItem.name}")
        }

    }

    /**
     * Constraints of a single schema: the entries of its {@code x-validations}
     * additive followed by the constraints declared on the schema, split into
     * a high-priority group (priority {@code <= 0}) and a low-priority group
     * (priority {@code > 0}).
     */
    static class ValidationDirectives {
        static final String DIRECTIVE_FIELD = 'field'
        static final String DIRECTIVE_BEAN = 'bean'
        static final String DIRECTIVE_DI_BEAN = 'di-bean'
        static final List<String> SUPPORTED_DIRECTIVES = [DIRECTIVE_FIELD, DIRECTIVE_DI_BEAN, DIRECTIVE_BEAN]

        List<DataConstraint> directives = []

        /**
         * Collects the directives of a schema.
         *
         * @param schema schema to read {@code x-validations} and constraints from
         * @return holder with the x-validation directives first, then the schema constraints
         */
        static ValidationDirectives forSchema(DataSchema schema) {
            ValidationDirectives loaded = new ValidationDirectives()
            loaded.loadXValidations(schema)
            loaded.directives.addAll(schema.constraints.constraints)
            return loaded
        }

        /**
         * Converts the entries of the {@code x-validations} additive into directives.
         * <p>
         * A string entry is a validator pointer expression. A map entry must
         * contain exactly one supported directive key and may carry an explicit
         * {@code priority}. The default priority is {@code 100 + index}. Any
         * other entry is reported with an error log and skipped.
         */
        private void loadXValidations(DataSchema schema) {
            def declared = schema.additives['x-validations'] as List
            if (!declared) {
                return
            }
            declared.eachWithIndex { entry, position ->
                if (entry instanceof String) {
                    directives.add(new ValidationDirective(
                            pointer: ValidatorPointer.fromExpression(entry),
                            priority: 100 + position
                    ))
                } else if (entry instanceof Map && isValidDirective(entry)) {
                    def directive = entry.entrySet().find { candidate -> SUPPORTED_DIRECTIVES.contains(candidate.key) }
                    directives.add(new ValidationDirective(
                            pointer: ValidatorPointer.fromExpression("${directive.key}:${directive.value}"),
                            priority: entry.priority != null ? entry.priority as Integer : (100 + position)
                    ))
                } else {
                    log.error("Can't handle 'x-validation': $entry")
                }
            }
        }

        /** A map entry is valid when it contains exactly one of the supported directive keys. */
        private boolean isValidDirective(Map spec) {
            SUPPORTED_DIRECTIVES.count { key -> spec.containsKey(key) } == 1
        }

        /** @return directives with priority {@code <= 0}, ordered by priority */
        List<DataConstraint> getHighPriority() {
            def selected = directives.findAll { constraint -> constraint.priority <= 0 }
            return selected.sort { constraint -> constraint.priority }
        }

        /** @return directives with priority {@code > 0}, ordered by priority */
        List<DataConstraint> getLowPriority() {
            def selected = directives.findAll { constraint -> constraint.priority > 0 }
            return selected.sort { constraint -> constraint.priority }
        }
    }

    /**
     * Constraint created from an {@code x-validations} entry; it carries the
     * pointer to the validator that {@code toJavaExpr} turns into a checker expression.
     */
    static class ValidationDirective extends DataConstraint {
        // Validator referenced by the x-validations entry.
        ValidatorPointer pointer
    }
}
