package de.monochromata.cucumber.stepdefs;

import static java.util.stream.Collectors.toMap;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.AbstractMap;
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Collectors;

import javax.tools.ToolProvider;

import de.monochromata.cucumber.stepdefs.compiler.InMemoryCompilerOutput;
import de.monochromata.cucumber.stepdefs.compiler.InMemoryCompilerSource;
import de.monochromata.cucumber.stepdefs.compiler.InMemoryOutputFileManager;
import io.cucumber.docstring.DocString;
import io.cucumber.java.en.Given;
import io.cucumber.java.en.When;

@SuppressWarnings({ "rawtypes", "unchecked" })
public class JavaCompilerStepdefs {

    /** State object the steps read from and write to (compiled classes, created instances). */
    private final JavaCompilerState<Object> state;

    /** Holder that receives an exception when it is captured instead of rethrown. */
    private final ExceptionState exceptionState;

    /**
     * Creates the step definitions on top of the given state objects.
     *
     * @param state          state that stores compiled classes and created instances
     * @param exceptionState holder for exceptions that are captured when
     *                       {@code state.catchExceptionsForAssertions} is set
     */
    public JavaCompilerStepdefs(final JavaCompilerState state, final ExceptionState exceptionState) {
        this.exceptionState = exceptionState;
        this.state = state;
    }
    
    /**
     * Compiles {@code className} from the doc string and makes it the current class.
     *
     * <p>On success, {@code state.clazz} is set to the compiled class and
     * {@code state.classes} is replaced by a fresh map that contains only that
     * class, keyed by {@code className}. The class is defined by a class loader
     * that delegates to its parent.
     *
     * <p>On failure, the exception is stored in {@code exceptionState.exception}
     * when {@code state.catchExceptionsForAssertions} is set; otherwise it is
     * rethrown wrapped in a {@link RuntimeException}.
     *
     * @param className  name of the class declared in the source
     * @param javaSource doc string holding the Java source code
     */
    @Given("a class {string} from source:")
    public void aClassFromSource(final String className, final DocString javaSource) {
        try {
            final var sourceText = javaSource.getContent();
            state.clazz = compileClass(className, sourceText, true);
            state.classes = new HashMap<>();
            state.classes.put(className, state.clazz);
        } catch (final Exception failure) {
            if (!state.catchExceptionsForAssertions) {
                throw new RuntimeException("Failed to compile/load class " + className + ", see standard error", failure);
            }
            // Keep the failure around so a later step can assert on it.
            exceptionState.exception = failure;
        }
    }
    
    /**
     * Compiles several classes in one compiler run and stores them in
     * {@code state.classes}.
     *
     * <p>The class names are given as a comma-separated list; whitespace around
     * the names is ignored, so both {@code "A,B"} and {@code "A, B"} are accepted.
     * The doc string is divided into one source fragment per class by the
     * separator {@code ---}. The classes are defined by a class loader that
     * delegates to its parent.
     *
     * <p>On failure, the exception is stored in {@code exceptionState.exception}
     * when {@code state.catchExceptionsForAssertions} is set; otherwise it is
     * rethrown wrapped in a {@link RuntimeException}.
     *
     * @param commaSeparatedClassNames comma-separated names of the classes, in
     *                                 the same order as the source fragments
     * @param sources                  doc string holding the source fragments,
     *                                 separated by {@code ---}
     */
    @Given("classes {string} from source:")
    public void classesFromSource(final String commaSeparatedClassNames, final DocString sources) {
        try {
            // Strip whitespace around the names: an untrimmed entry such as " B"
            // (from "A, B") is not a usable class name.
            final var classNames = commaSeparatedClassNames.trim().split("\\s*,\\s*");
            final var documents = sources.getContent().split("---");
            state.classes = compileClasses(classNames, documents, true);
        } catch (final Exception e) {
            if(state.catchExceptionsForAssertions) {
                exceptionState.exception = e;
            } else {
                throw new RuntimeException("Failed to compile/load classes " + commaSeparatedClassNames + ", see standard error", e);
            }
        }
    }
    
    /**
     * Compiles {@code className} from the doc string and defines it with a class
     * loader that does not delegate to its parent.
     *
     * <p>On success only {@code state.clazz} is set; {@code state.classes} is
     * left untouched by this step.
     *
     * <p>On failure, the exception is stored in {@code exceptionState.exception}
     * when {@code state.catchExceptionsForAssertions} is set; otherwise it is
     * rethrown wrapped in a {@link RuntimeException}.
     *
     * @param className  name of the class declared in the source
     * @param javaSource doc string holding the Java source code
     */
    @Given("a class {string} from source defined by a class loader that does not delegate to its parent:")
    public void aClassFromSourceOSGI(final String className, final DocString javaSource) {
        try {
            final var sourceText = javaSource.getContent();
            state.clazz = compileClass(className, sourceText, false);
        } catch (final Exception failure) {
            if (!state.catchExceptionsForAssertions) {
                throw new RuntimeException("Failed to compile/load class "+className+", see standard error", failure);
            }
            // Keep the failure around so a later step can assert on it.
            exceptionState.exception = failure;
        }
    }
    
    /**
     * Instantiates the current class ({@code state.clazz}) via its no-args
     * constructor and stores the result in {@code state.instance}.
     *
     * <p>If no current class is set, an {@link IllegalStateException} describing
     * the problem is raised instead of relying on a {@code NullPointerException}.
     * Like every other failure of this step, it is stored in
     * {@code exceptionState.exception} when
     * {@code state.catchExceptionsForAssertions} is set; otherwise it is rethrown
     * wrapped in a {@link RuntimeException}.
     */
    @Given("an instance of the class")
    @When("an instance of the class is created")
    public void anInstanceOfTheClassIsCreated() {
        try {
            final var clazz = state.clazz;
            if (clazz == null) {
                // Explicit check: a missing class would otherwise surface as an
                // anonymous NullPointerException.
                throw new IllegalStateException("No class is available to instantiate - was a class compiled before this step?");
            }
            state.instance = clazz.getDeclaredConstructor().newInstance();
        } catch (final Exception e) {
            if(state.catchExceptionsForAssertions) {
                exceptionState.exception = e;
            } else {
                throw new RuntimeException("Failed to instantiate class via no-args constructor", e);
            }
        }
    }

    /**
     * Instantiates the class registered under {@code typeName} in
     * {@code state.classes} via its no-args constructor and stores the instance
     * in {@code state.instances} under the same name.
     *
     * <p>If no class is registered under {@code typeName}, an
     * {@link IllegalStateException} naming the requested type is raised instead
     * of relying on a {@code NullPointerException}. Like every other failure of
     * this step, it is stored in {@code exceptionState.exception} when
     * {@code state.catchExceptionsForAssertions} is set; otherwise it is rethrown
     * wrapped in a {@link RuntimeException}.
     *
     * @param typeName name under which the class was registered
     */
    @Given("an instance of {string}")
    @When("an instance of {string} is created")
    public void anInstanceIsCreated(final String typeName) {
        try {
            final var classes = state.classes;
            final var clazz = classes == null ? null : classes.get(typeName);
            if (clazz == null) {
                // Name the requested type so a typo in the feature file is easy to spot.
                throw new IllegalStateException("No compiled class is registered under the name " + typeName);
            }
            final var instance = clazz.getDeclaredConstructor().newInstance();
            state.instances.put(typeName, instance);
        } catch (final Exception e) {
            if(state.catchExceptionsForAssertions) {
                exceptionState.exception = e;
            } else {
                throw new RuntimeException("Failed to instantiate class via no-args constructor", e);
            }
        }
    }
    
    /**
     * Compiles a single source and returns the class that was asked for.
     *
     * <p>One source can produce several class files (nested or anonymous
     * classes), so the result map may contain more than one entry. The class is
     * therefore looked up by name: first by an exact match on {@code className},
     * then by a match on the simple name (for sources that declare a package
     * while {@code className} is unqualified). Only if neither matches is an
     * arbitrary compiled class returned.
     *
     * @param className                   name of the class declared in the source
     * @param javaSource                  the Java source code
     * @param delegateToParentClassLoader whether the defining class loader
     *                                    delegates to its parent
     * @return the compiled and loaded class
     * @throws IOException if compilation raises it
     * @throws java.util.NoSuchElementException if compilation produced no class
     */
    protected Class compileClass(
            final String className, 
            final String javaSource, 
            final boolean delegateToParentClassLoader) throws IOException {
        final var classes = compileClasses(new String[] { className }, new String[] { javaSource }, delegateToParentClassLoader);
        final Class<?> exactMatch = classes.get(className);
        if (exactMatch != null) {
            return exactMatch;
        }
        // Nested classes use '$' in their binary names, so this only matches
        // the top-level class with the requested simple name.
        final var simpleNameSuffix = "." + className;
        for (final var entry : classes.entrySet()) {
            if (entry.getKey().endsWith(simpleNameSuffix)) {
                return entry.getValue();
            }
        }
        return classes.values().iterator().next();
    }

    /**
     * Compiles the given sources in memory in a single compiler run and loads
     * all resulting classes.
     *
     * @param classNames                  names of the classes, one per source
     * @param sources                     the Java sources, in the same order as
     *                                    {@code classNames}
     * @param delegateToParentClassLoader whether the defining class loader
     *                                    delegates to its parent
     * @return the loaded classes as returned by {@code defineClasses}
     * @throws IOException              if closing the standard file manager fails
     * @throws IllegalStateException    if no system Java compiler is available
     * @throws IllegalArgumentException if the number of names and sources differ
     * @throws RuntimeException         if the compiler reports a failure
     */
    protected Map<String,Class<?>> compileClasses(
            final String[] classNames,
            final String[] sources, 
            final boolean delegateToParentClassLoader) throws IOException {
        final var compiler = ToolProvider.getSystemJavaCompiler();
        if (compiler == null) {
            // getSystemJavaCompiler() returns null when no compiler is provided,
            // e.g. when running on a runtime without the jdk.compiler module.
            throw new IllegalStateException("No system Java compiler is available - run on a JDK that provides one");
        }
        // The standard file manager holds resources and must be closed after use.
        try (final var standardFileManager = compiler.getStandardFileManager(null, null, null)) {
            final var fileManager = new InMemoryOutputFileManager(standardFileManager);
            final var javaFileObjects = javaFileObjects(classNames, sources);
            final var successful = compiler.getTask(null, fileManager, null, null, null, javaFileObjects).call();
            if(!successful) {
                throw new RuntimeException("Compilation failed");
            }
            return defineClasses(fileManager.outputs, delegateToParentClassLoader);
        }
    }
    
    /**
     * Pairs each class name with the source fragment at the same position and
     * wraps each pair in an {@link InMemoryCompilerSource}.
     *
     * @param classNames names of the classes
     * @param sources    source fragments, in the same order as {@code classNames}
     * @return one compiler source per class name, in input order
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    protected List<InMemoryCompilerSource> javaFileObjects(
            final String[] classNames,
            final String[] sources) {
        final int nameCount = classNames.length;
        final int fragmentCount = sources.length;
        if (nameCount != fragmentCount) {
            throw new IllegalArgumentException("Mismatch: you defined " + nameCount
                    + " class name(s) but the source file is divided into " + fragmentCount + " source fragment(s)"
                    + " - they should match");
        }
        final List<InMemoryCompilerSource> compilerSources = new ArrayList<>(nameCount);
        int position = 0;
        for (final String className : classNames) {
            compilerSources.add(new InMemoryCompilerSource(className, sources[position]));
            position++;
        }
        return compilerSources;
    }

    /**
     * Loads all compiled classes through a new {@link DefiningClassLoader}.
     *
     * @param outputs                     compiler outputs to define classes from
     * @param delegateToParentClassLoader {@code true} to use a loader that
     *                                    delegates to its parent, {@code false}
     *                                    to use one that does not
     * @return the loaded classes as returned by the loader
     */
    protected Map<String,Class<?>> defineClasses(final Map<String,InMemoryCompilerOutput> outputs, 
            final boolean delegateToParentClassLoader) {
        final DefiningClassLoader loader = delegateToParentClassLoader
                ? DefiningClassLoader.instanceDelegatingToParent(outputs)
                : DefiningClassLoader.instanceNotDelegatingToParent(outputs);
        return loader.loadAllClasses();
    }
    
	/**
	 * Class loader that defines classes from in-memory compiler outputs.
	 *
	 * <p>Classes contained in the outputs are defined by this loader itself
	 * before any delegation takes place; all other classes are loaded via
	 * {@link ClassLoader#loadClass(String, boolean)}. The class data of the
	 * outputs is also exposed as resources named like class files
	 * (e.g. {@code a/b/C.class}).
	 */
	protected static class DefiningClassLoader extends ClassLoader {

		/** Compiler outputs, looked up by class name when a class is loaded. */
		private final Map<String,InMemoryCompilerOutput> outputs;
		/** Classes already defined by this loader, keyed by class name. */
		private final Map<String,Class<?>> definedClasses = new HashMap<>();
		/** Class data of the outputs, keyed by resource name. */
		private final Map<String,byte[]> resources;

		/**
		 * @param parent  the parent class loader, or {@code null} for none
		 * @param outputs compiler outputs to define classes from
		 */
		protected DefiningClassLoader(
				final ClassLoader parent,
				final Map<String,InMemoryCompilerOutput> outputs) {
			super(null, parent);
			this.outputs = outputs;
			this.resources = createResources(outputs);
		}

		/**
		 * Maps every output to a resource entry holding its class data.
		 *
		 * @param outputs compiler outputs keyed by class name
		 * @return class data keyed by resource name
		 */
		protected static Map<String, byte[]> createResources(final Map<String, InMemoryCompilerOutput> outputs) {
			return outputs.entrySet().stream()
					.map(DefiningClassLoader::createResourceEntry)
					.collect(toMap(Entry::getKey, Entry::getValue));
		}

		/**
		 * Converts a class name into a resource name by replacing dots with
		 * slashes and appending {@code .class}.
		 *
		 * @param entry an output keyed by class name
		 * @return the resource name paired with the class data of the output
		 */
		protected static Entry<String,byte[]> createResourceEntry(Map.Entry<String,InMemoryCompilerOutput> entry) {
			final var resourceName = entry.getKey().replace('.', '/') + ".class";
			return new AbstractMap.SimpleImmutableEntry<>(resourceName, entry.getValue().getClassData());
		}

		/**
		 * Loads the class of every output.
		 *
		 * @return the loaded classes keyed by class name
		 * @throws IllegalStateException if one of the classes cannot be loaded
		 */
		public Map<String,Class<?>> loadAllClasses() {
			return outputs.entrySet().stream()
					.map(this::loadClassEntry)
					.collect(toMap(Entry::getKey, Entry::getValue));
		}

		/**
		 * Loads the class named by the key of the given entry.
		 *
		 * @param entry an output keyed by class name
		 * @return the class name paired with the loaded class
		 * @throws IllegalStateException if the class cannot be loaded
		 */
		protected Entry<String, Class<?>> loadClassEntry(Entry<String, InMemoryCompilerOutput> entry) {
			try {
				return new AbstractMap.SimpleImmutableEntry<>(entry.getKey(), loadClass(entry.getKey()));
			} catch (final ClassNotFoundException e) {
				throw new IllegalStateException("Could not load internally-defined class", e);
			}
		}

		/**
		 * Returns the class defined from the outputs if there is one for
		 * {@code name}, defining it on first use; otherwise falls back to the
		 * default loading of {@link ClassLoader}.
		 *
		 * <p>The lookup runs under the class loading lock so that a class is
		 * not defined twice when it is requested concurrently, and
		 * {@code resolve} is honoured for classes defined by this loader as well.
		 */
		@Override
		protected Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException {
			synchronized (getClassLoadingLock(name)) {
				Class<?> clazz = definedClasses.get(name);
				if (clazz == null) {
					final var output = outputs.get(name);
					if (output == null) {
						return super.loadClass(name, resolve);
					}
					clazz = defineClass(name, output.getClassData());
				}
				if (resolve) {
					resolveClass(clazz);
				}
				return clazz;
			}
		}

		/**
		 * Returns the class data of an in-memory class if {@code name} denotes
		 * one; otherwise uses the regular resource lookup of
		 * {@link ClassLoader}, so that resources reachable via the parent stay
		 * visible.
		 *
		 * @return a stream over the resource, or {@code null} if it is not found
		 */
		@Override
		public InputStream getResourceAsStream(final String name) {
			final var data = resources.get(name);
			if (data != null) {
				return new ByteArrayInputStream(data);
			}
			if (name == null) {
				// The regular lookup rejects null names; keep answering "not found".
				return null;
			}
			return super.getResourceAsStream(name);
		}

		/**
		 * Defines a class from the given class data and remembers it.
		 *
		 * @param className name of the class
		 * @param classData the bytes of the class file
		 * @return the defined class
		 */
		protected Class<?> defineClass(final String className, final byte[] classData) {
			final var clazz = defineClass(className, classData, 0, classData.length);
			definedClasses.put(className, clazz);
			return clazz;
		}

		/**
		 * Creates a loader whose parent is the system class loader.
		 *
		 * @param outputs compiler outputs to define classes from
		 */
		public static DefiningClassLoader instanceDelegatingToParent(
				final Map<String,InMemoryCompilerOutput> outputs) {
			return new DefiningClassLoader(ClassLoader.getSystemClassLoader(), outputs);
		}

		/**
		 * Creates a loader without a parent class loader.
		 *
		 * @param outputs compiler outputs to define classes from
		 */
		public static DefiningClassLoader instanceNotDelegatingToParent(
				final Map<String,InMemoryCompilerOutput> outputs) {
			return new DefiningClassLoader(null, outputs);
		}
	}
    
}
