/*
 * Copyright © 2016-2024 Lightbend, Inc. All rights reserved.
 * No information contained herein may be reproduced or transmitted in any form
 * or by any means without the express written permission of Lightbend, Inc.
 */

package com.lightbend.tools.fortify.plugin

import dotty.tools.dotc.plugins.PluginPhase
import java.io.File

/** Compiler plugin entry point for the Fortify NST translator.
  *
  * `init` builds a fresh [[FortifyPhase]], applies the plugin options to it,
  * runs the license check, and prepares the output directory together with
  * the session lock file.
  */
class FortifyPlugin extends VersionSpecificPlugin:
  override val name = "fortify"
  override val description = "compile Scala to Fortify NST"
  override val optionsHelp: Option[String] = Some(PluginOptions.HelpString)

  /** Configures and returns the single phase contributed by this plugin.
    *
    * @param options the raw `key=value` option strings passed to the plugin
    * @param error   callback used to report an unrecognized or malformed option;
    *                processing continues with the remaining options afterwards
    * @return a one-element list holding the configured phase
    */
  override def init(options: List[String], error: String => Unit) =
    // the phase must be freshly constructed inside `init`,
    // as per lampepfl/dotty#18381
    val phase = FortifyPhase()
    import PluginOptions.RegexInterpolator

    // Parses the value of a boolean option. A value that is neither "true"
    // nor "false" is reported through `error`, like any other bad argument,
    // rather than escaping `init` as an IllegalArgumentException.
    def booleanValue(key: String, raw: String): Option[Boolean] =
      val parsed = raw.toBooleanOption
      if parsed.isEmpty then
        error(s"Bad argument: $key=$raw")
      parsed

    options.foreach:
      case r"build=(.*)$id" =>
        phase.buildId = Some(id)
      case r"scaversion=(.*)$version" =>
        phase.scaVersion = version
      case r"out=(.*)$out" =>
        phase.outputDir = Some(File(out))
      case r"exclude=(.*)$specs" =>
        phase.excludes ++= Paths.parseExcludes(specs)
      // remaining options are undocumented
      case r"showSourceInfo=(.*)$show" =>
        booleanValue("showSourceInfo", show)
          .foreach(flag => phase.showSourceInfo = flag)
      case r"suppressEntitlementCheck=(.*)$suppress" =>
        booleanValue("suppressEntitlementCheck", suppress)
          .foreach(flag => phase.suppressEntitlementCheck = flag)
      case r"trace=(.*)$trace" =>
        booleanValue("trace", trace)
          .foreach(flag => phase.trace = flag)
      case arg =>
        error(s"Bad argument: $arg")

    // a build id without an explicit `out=` implies a derived output directory
    for
      id <- phase.buildId
      if phase.outputDir.isEmpty
    do
      phase.outputDir =
        Some(PluginOptions.pathForBuildId(phase.scaVersion, id))

    LicenseChecker.verifyEntitlements(
      phase.licensePath,
      phase.suppressEntitlementCheck,
      println)

    // create the output directory and, next to it, the lock file for this
    // build id (only when both a build id and an output directory are known)
    for
      id <- phase.buildId
      dir <- phase.outputDir
    do
      dir.mkdirs()
      File(dir.getParent, s"$id.scasession.lock")
        .createNewFile()

    List(phase)

import dotty.tools.dotc
import dotc.core.Contexts.Context
import dotc.ast.tpd.Tree

/** Compiler phase that writes one NST file per compilation unit.
  *
  * The phase is scheduled after `MoveStatics` and before `GenBCode`. Every
  * translated unit is recorded in `session`; once the run is over the
  * recorded entries are appended to the session file that lives next to the
  * output directory.
  */
class FortifyPhase extends PluginPhase with PluginOptions:
  override val phaseName = "fortify"
  override val runsAfter = Set(dotty.tools.dotc.transform.MoveStatics.name)
  override val runsBefore = Set(dotty.tools.backend.jvm.GenBCode.name)

  /** Entries for the units translated so far (source, NST output, line count). */
  val session = collection.mutable.Buffer[SessionWriter.Entry]()

  /** Runs the phase, then appends the recorded entries to the
    * `<buildId>.scasession.increment` file. Nothing is written unless at
    * least one unit was translated and both a build id and an output
    * directory are configured.
    */
  override def run(using Context): Unit =
    super.run
    // nothing was translated (e.g. in REPL), so there is nothing to record
    if session.nonEmpty then
      import java.io.{FileOutputStream, OutputStreamWriter, PrintWriter}
      import java.nio.charset.StandardCharsets
      for
        id <- buildId
        dir <- outputDir
      do
        // the whole append is done while holding the SessionWriter monitor
        SessionWriter.synchronized:
          val sessionFile =
            new File(dir.getParent, s"$id.scasession.increment")
          // maybe a bit janky to reuse this flag for this, but we don't want this
          // noise when doing development work
          if !suppressEntitlementCheck then
            println(s"scala-fortify: writing translated files to ${dir.getAbsolutePath}")
          // must be checked before the stream is opened, since opening creates the file
          val alreadyExisted = sessionFile.exists
          val writer = new PrintWriter(
            new OutputStreamWriter(
              new FileOutputStream(sessionFile, true), // append = true
              StandardCharsets.UTF_8))
          try
            // not sure if it's important to match the Java translator's behavior
            // of starting a new file with a blank line, but we might as well
            if !alreadyExisted then
              writer.println()
            SessionWriter.write(
              writer,
              id,
              session.toSeq,
              summon[Context].settings.encoding.value)
          finally writer.close()

  /** Translates one compilation unit to NST unless its source file is excluded.
    *
    * When the source carries Twirl metadata, the original template path and
    * the Twirl line mapping are handed to the translator; otherwise the
    * unit's own path and the identity mapping are used.
    *
    * @param tree the tree of the compilation unit
    * @return `tree`, unchanged
    */
  override def transformUnit(tree: Tree)(using Context): Tree =
    val sourceFile = tree.source.file.file
    if !Paths.isExcluded(excludes, sourceFile) then
      val translator =
        if trace
        then TracingTranslator()
        else Translator()
      val actualOutputDir = outputDir.getOrElse(new File("."))
      val outputPath =
        Paths.sourcePathToNstPath(sourceFile, actualOutputDir)
      // Do everything that can fail on the *input* side before the output
      // file is opened, so a failure here neither leaks a writer nor leaves
      // an empty NST file behind.
      val lineCount = countLines(sourceFile)
      val (originalPath, lineMapper) =
        val twirl = Twirl(tree.source.content.mkString)
        twirl.originalPath match
          case Some(path) =>
            (path, twirl.mapLine)
          case None =>
            (tree.source.file.path, identity[Int])
      outputPath.getParentFile.mkdirs()
      val writer = java.io.PrintWriter(outputPath)
      try
        session += SessionWriter.Entry(sourceFile, outputPath, lineCount)
        translator.apply(
          path = originalPath,
          source = tree.source,
          tree = tree,
          lineMapper = lineMapper,
          out = writer
        )
      finally writer.close()
    end if
    tree

  /** Counts the lines of `file`, decoding it with the compiler's configured
    * source encoding.
    */
  private def countLines(file: File)(using Context): Int =
    val source = io.Source.fromFile(file)(summon[Context].settings.encoding.value)
    try source.getLines().size
    finally source.close()
