package de.monochromata.cucumber.stepdefs;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import javax.tools.ToolProvider;

import de.monochromata.cucumber.stepdefs.compiler.InMemoryCompilerOutput;
import de.monochromata.cucumber.stepdefs.compiler.InMemoryCompilerSource;
import de.monochromata.cucumber.stepdefs.compiler.InMemoryOutputFileManager;
import io.cucumber.docstring.DocString;
import io.cucumber.java.en.Given;
import io.cucumber.java.en.When;

/**
 * Cucumber step definitions that compile Java sources in memory, load the resulting classes and
 * instantiate them.
 *
 * <p>When {@code state.catchExceptionsForAssertions} is set, failures are stored in
 * {@link ExceptionState} so that later steps can assert on them. Otherwise they are rethrown
 * wrapped in a {@link RuntimeException}.
 */
@SuppressWarnings({ "rawtypes", "unchecked" })
public class JavaCompilerStepdefs {

    private final JavaCompilerState<Object> state;
    private final ExceptionState exceptionState;

    /**
     * @param state          shared compiler state that receives compiled classes and instances
     * @param exceptionState shared state that receives caught exceptions for later assertions
     */
    public JavaCompilerStepdefs(
            final JavaCompilerState state,
            final ExceptionState exceptionState) {
        this.state = state;
        this.exceptionState = exceptionState;
    }

    /**
     * Compiles a single class and loads it via a class loader that delegates to the system class
     * loader. The result is stored in {@code state.clazz}.
     */
    @Given("a class {string} from source:")
    public void aClassFromSource(final String className, final DocString javaSource) {
        try {
            state.clazz = compileClass(className, javaSource.getContent(), true);
        } catch (final Exception e) {
            recordOrRethrow(e, "Failed to compile/load class " + className + ", see standard error");
        }
    }

    /**
     * Compiles several classes together and stores them in {@code state.classes}.
     *
     * <p>Class names are given comma-separated; whitespace around the commas is ignored. The
     * doc string holds one source fragment per class, in the same order, separated by
     * {@code ---}.
     */
    @Given("classes {string} from source:")
    public void classesFromSource(final String commaSeparatedClassNames, final DocString sources) {
        try {
            // Trim around commas so that "a, b" yields "b" rather than " b", which could never
            // match a compiled class name.
            final var classNames = commaSeparatedClassNames.trim().split("\\s*,\\s*");
            final var documents = sources.getContent().split("---");
            state.classes = compileClasses(classNames, documents, true);
        } catch (final Exception e) {
            recordOrRethrow(e, "Failed to compile/load classes " + commaSeparatedClassNames + ", see standard error");
        }
    }

    /**
     * Compiles a single class and loads it via a class loader whose parent is the bootstrap
     * loader. The system class loader is therefore not consulted. The result is stored in
     * {@code state.clazz}.
     */
    @Given("a class {string} from source defined by a class loader that does not delegate to its parent:")
    public void aClassFromSourceOSGI(final String className, final DocString javaSource) {
        try {
            state.clazz = compileClass(className, javaSource.getContent(), false);
        } catch (final Exception e) {
            recordOrRethrow(e, "Failed to compile/load class " + className + ", see standard error");
        }
    }

    /** Instantiates {@code state.clazz} via its no-args constructor into {@code state.instance}. */
    @When("an instance of the class is created")
    public void anInstanceOfTheClassIsCreated() {
        try {
            state.instance = state.clazz.getDeclaredConstructor().newInstance();
        } catch (final Exception e) {
            recordOrRethrow(e, "Failed to instantiate class via no-args constructor");
        }
    }

    /**
     * Instantiates the previously compiled class named {@code typeName} via its no-args
     * constructor. The instance is stored in {@code state.instances} under that name.
     */
    @When("an instance of {string} is created")
    public void anInstanceIsCreated(final String typeName) {
        try {
            final var clazz = state.classes.get(typeName);
            if (clazz == null) {
                // Report the unknown name explicitly instead of failing with an opaque NPE.
                throw new IllegalArgumentException("Unknown class " + typeName
                        + " - compiled classes are " + state.classes.keySet());
            }
            final var instance = clazz.getDeclaredConstructor().newInstance();
            state.instances.put(typeName, instance);
        } catch (final Exception e) {
            recordOrRethrow(e, "Failed to instantiate class via no-args constructor");
        }
    }

    /**
     * Stores {@code e} for later assertions if the scenario expects exceptions. Otherwise
     * rethrows it wrapped in a {@link RuntimeException} carrying {@code message}.
     */
    private void recordOrRethrow(final Exception e, final String message) {
        if (state.catchExceptionsForAssertions) {
            exceptionState.exception = e;
        } else {
            throw new RuntimeException(message, e);
        }
    }

    /**
     * Compiles a single source and returns the class named {@code className}.
     *
     * <p>The source may produce additional classes (nested, inner or anonymous). The requested
     * class is therefore looked up by name rather than taking an arbitrary entry from the result.
     *
     * @param className                   name of the class declared by {@code javaSource}
     * @param javaSource                  the Java source text
     * @param delegateToParentClassLoader whether the defining loader delegates to the system loader
     * @return the requested class. If no class with exactly that name was produced, falls back to
     *         an arbitrary compiled class (the previous behavior).
     * @throws IOException           if the compiler's file manager fails to close
     * @throws IllegalStateException if compilation produced no classes at all
     */
    protected Class compileClass(
            final String className,
            final String javaSource,
            final boolean delegateToParentClassLoader) throws IOException {
        final var classes = compileClasses(new String[] { className }, new String[] { javaSource }, delegateToParentClassLoader);
        final var requested = classes.get(className);
        if (requested != null) {
            return requested;
        }
        if (classes.isEmpty()) {
            throw new IllegalStateException("Compiling " + className + " produced no classes");
        }
        return classes.values().iterator().next();
    }

    /**
     * Compiles the given sources in memory and defines every resulting class in a fresh class
     * loader.
     *
     * <p>The compiler is called without a writer or diagnostic listener, so compiler messages
     * are reported on standard error.
     *
     * @param classNames                  one class name per source fragment
     * @param sources                     the source fragments, parallel to {@code classNames}
     * @param delegateToParentClassLoader whether the defining loader delegates to the system loader
     * @return all compiled classes (including nested ones) keyed by their binary names
     * @throws IOException              if the compiler's file manager fails to close
     * @throws IllegalArgumentException if the two arrays differ in length
     * @throws IllegalStateException    if no system Java compiler is available (running on a JRE)
     * @throws RuntimeException         if compilation fails
     */
    protected Map<String,Class<?>> compileClasses(
            final String[] classNames,
            final String[] sources,
            final boolean delegateToParentClassLoader) throws IOException {
        final var javaFileObjects = javaFileObjects(classNames, sources);
        final var compiler = ToolProvider.getSystemJavaCompiler();
        if (compiler == null) {
            throw new IllegalStateException("No system Java compiler available - run on a JDK, not a JRE");
        }
        // The standard file manager holds open resources (e.g. archives) and must be closed.
        try (final var standardFileManager = compiler.getStandardFileManager(null, null, null)) {
            final var fileManager = new InMemoryOutputFileManager(standardFileManager);
            final var successful = compiler.getTask(null, fileManager, null, null, null, javaFileObjects).call();
            if (!successful) {
                throw new RuntimeException("Compilation failed");
            }
            return defineClasses(fileManager.outputs, delegateToParentClassLoader);
        }
    }

    /**
     * Pairs class names with source fragments to form in-memory compilation units.
     *
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    protected List<InMemoryCompilerSource> javaFileObjects(
            final String[] classNames,
            final String[] sources) {
        if (classNames.length != sources.length) {
            throw new IllegalArgumentException("Mismatch: you defined " + classNames.length
                    + " class name(s) but the source file is divided into " + sources.length + " source fragment(s)"
                    + " - they should match");
        }
        final List<InMemoryCompilerSource> javaFileObjects = new ArrayList<>(classNames.length);
        for (int i = 0; i < classNames.length; i++) {
            javaFileObjects.add(new InMemoryCompilerSource(classNames[i], sources[i]));
        }
        return javaFileObjects;
    }

    /**
     * Defines the compiled classes in a new {@link DefiningClassLoader}.
     *
     * @param outputs                     compiled class files keyed by binary class name
     * @param delegateToParentClassLoader {@code true} to use the system class loader as parent,
     *                                    {@code false} to use the bootstrap loader only
     * @return the defined classes keyed by binary class name
     */
    protected Map<String,Class<?>> defineClasses(final Map<String,InMemoryCompilerOutput> outputs,
            final boolean delegateToParentClassLoader) {
        if (delegateToParentClassLoader) {
            return DefiningClassLoader.instanceDelegatingToParent(outputs).definedClasses;
        }
        return DefiningClassLoader.instanceNotDelegatingToParent(outputs).definedClasses;
    }

    /**
     * A class loader that eagerly defines all given compiled classes on construction.
     *
     * <p>Defining a class makes the JVM load its superclass and interfaces immediately. If one
     * compiled class extends another, that other class must be resolvable at that point,
     * whatever order the map iterates in. {@link #findClass(String)} therefore defines
     * sibling classes from {@code outputs} on demand.
     */
    protected static class DefiningClassLoader extends ClassLoader {

        protected final Map<String,Class<?>> definedClasses = new HashMap<>();

        private final Map<String,InMemoryCompilerOutput> outputs;

        /**
         * @param parent  the parent loader, or {@code null} for the bootstrap loader
         * @param outputs compiled class files keyed by binary class name
         */
        protected DefiningClassLoader(
                final ClassLoader parent,
                final Map<String,InMemoryCompilerOutput> outputs) {
            super(null, parent);
            this.outputs = outputs;
            for (final String className : outputs.keySet()) {
                definedClasses.put(className, defineIfAbsent(className));
            }
        }

        /**
         * Defines a compiled class on demand, e.g. when it is the superclass of the class being
         * defined. Called by {@link ClassLoader#loadClass(String)} after the parent failed.
         */
        @Override
        protected Class<?> findClass(final String name) throws ClassNotFoundException {
            if (!outputs.containsKey(name)) {
                throw new ClassNotFoundException(name);
            }
            return defineIfAbsent(name);
        }

        /** Returns the class if this loader already defined it, otherwise defines it from its bytes. */
        private Class<?> defineIfAbsent(final String className) {
            // It may already have been defined on demand while resolving another class's supertypes;
            // defining it twice would throw a LinkageError.
            final var alreadyDefined = findLoadedClass(className);
            if (alreadyDefined != null) {
                return alreadyDefined;
            }
            final var bytes = outputs.get(className).getClassData();
            return defineClass(className, bytes, 0, bytes.length);
        }

        /** Creates a loader whose parent is the system class loader. */
        public static DefiningClassLoader instanceDelegatingToParent(
                final Map<String,InMemoryCompilerOutput> outputs) {
            return new DefiningClassLoader(ClassLoader.getSystemClassLoader(), outputs);
        }

        /** Creates a loader whose parent is the bootstrap loader (no system class path). */
        public static DefiningClassLoader instanceNotDelegatingToParent(
                final Map<String,InMemoryCompilerOutput> outputs) {
            return new DefiningClassLoader(null, outputs);
        }
    }

}
