/*
 * Copyright © 2016-2024 Lightbend, Inc. All rights reserved.
 * No information contained herein may be reproduced or transmitted in any form
 * or by any means without the express written permission of Lightbend, Inc.
 */

package com.lightbend.tools.fortify.plugin

import scala.reflect.internal.Flags
import scala.tools.nsc

import com.fortify.frontend.nst

import nst._
import nodes._

// This is where we put pieces of Translator that are stateless
// and context-independent, yet Scala-version-specific.

trait TranslatorHelpers[T <: nsc.ast.Trees with nsc.symtab.SymbolTable]
    extends TranslatorBase
    with VersionSpecificHelpers[T]
    with Positions {

  // the compiler universe (trees + symbol table) whose symbols are translated;
  // supplied by the class that mixes this trait in
  val global: T
  import global._
  // pin `Symbol` and `Tree` to the path-dependent types of `global`
  type Symbol = global.Symbol
  type Tree = global.Tree

  // hook invoked by the helpers in this trait for every symbol they reference
  // while building names, types and declarations; what is done with the symbol
  // is up to the implementing class
  def seen(symbol: Symbol): Unit

  // Builds the NST class declaration for `symbol`: simple and full names,
  // namespace (unless the owner is the empty package), modifiers, and one
  // `extends` entry per parent other than `Any`. Reports `symbol` via `seen`.
  def toClassDecl(symbol: Symbol): STClassDecl = {
    seen(symbol)
    val decl = new STClassDecl(symbol.pos)
    decl.setSimpleName(symbol.name.decoded)
    decl.setName(className(symbol))
    val owner = symbol.owner
    if (!owner.isEmptyPackageClass)
      decl.setNamespace(owner.fullName)
    // Interface wins over Abstract; at most one of the two is added
    if (symbol.isInterface)
      decl.addModifiers(NSTModifiers.Interface)
    else if (isAbstract(symbol))
      decl.addModifiers(NSTModifiers.Abstract)
    if (isSyntheticClass(symbol))
      decl.addModifiers(NSTModifiers.Synthetic)
    if (hasJavaEnumFlag(symbol)) {
      // emit same modifiers Java translator emits
      Seq(NSTModifiers.Enum, NSTModifiers.Final, NSTModifiers.Static)
        .foreach(modifier => decl.addModifiers(modifier))
    }
    decl.addModifiers(NSTModifiers.Public)
    symbol.parentSymbols
      .filter(parent => parent != definitions.AnyClass)
      .foreach(parent => decl.addExtends(typeForSymbol(parent, ref = false)))
    decl
  }

  // Builds the NST function declaration for the method `symbol`.
  // `params` are the method's parameter symbols, in declaration order.
  // Non-static methods get an explicit leading "this~" parameter typed as
  // the owner. Reports `symbol` via `seen`.
  def toFunDecl(symbol: Symbol, params: List[Symbol]): STFunDecl = {
    seen(symbol)
    val decl = new STFunDecl(symbol.pos)
    decl.setName(methodName(symbol))
    decl.setSimpleName(toSimpleName(symbol))
    val isStatic = symbol.isStaticMember
    if (!isStatic) {
      val owner = symbol.owner
      decl.addParameter(new STVarDecl(owner.pos, "this~", typeForSymbol(owner)))
    }
    params.foreach { param =>
      decl.addParameter(
        new STVarDecl(param.pos, variableName(param, uniquify = false), typeForType(param.tpe)))
    }
    val resultType = symbol.tpe.resultType
    val resultSymbol = resultType.typeSymbol
    // in NST constructors have void return type
    val isVoid = symbol.isConstructor || resultSymbol == definitions.UnitClass
    decl.setReturnType(
      if (isVoid)
        VoidType
      else if (symbol == definitions.Object_synchronized ||
        symbol == definitions.Object_asInstanceOf)
        // these two are given the "any" type rather than their declared result type
        new STType.STAnyType
      else if (symbol == definitions.Object_isInstanceOf)
        STType.makePrimitiveBoolean(new SourceInfo)
      else
        typeForType(resultType))
    decl.setModifiers(NSTModifiers.Public)
    // isLateDeferred comes from VersionSpecificHelpers
    if (isAbstract(symbol) && !isLateDeferred(symbol))
      decl.setModifiers(NSTModifiers.Abstract)
    if (isStatic)
      decl.setModifiers(NSTModifiers.Static)
    if (isSyntheticMethod(symbol))
      decl.setModifiers(NSTModifiers.Synthetic)
    overrides(symbol).foreach(overridden => decl.addOverride(methodName(overridden)))
    decl
  }

  // All parameter symbols of `dd`, with its parameter lists flattened
  // into a single list in declaration order.
  def paramNames(dd: DefDef): List[Symbol] =
    for {
      paramList <- dd.vparamss
      param <- paramList
    } yield param.symbol

  // Returns the methods overridden by `symbol`, as computed at the end of the
  // uncurry phase. Yields Nil for members of implementation classes and for
  // `Object.equals`; `Any.equals` is always dropped from the result.
  def overrides(symbol: Symbol): List[Symbol] =
    if (isImplClass(symbol.owner)) // via VersionSpecificHelpers
      Nil
    else if (symbol == definitions.Object_equals)
      // special case: don't let a `scala.Any` parent get in,
      // we generally try to keep `Any` out of the closure, especially
      // when it's clearly Java stuff like this
      Nil
    else
      global
        // gross typecast, but pulling in all of Global above just
        // for this one thing would be gross too.. :-\
        .asInstanceOf[nsc.Global]
        // we need to time travel here, because eventually this information gets lost. we had
        // `exitingTyper` but that turned out to be too far back in some cases, apparently because
        // when multiple parameter lists get flattened by uncurry, it makes new symbols? so let's
        // not time travel quite so far. see lightbend/scala-fortify#297
        .exitingUncurry(symbol.overrides)
        // further keeping `Any` out of the closure
        .filterNot(_ == definitions.Any_equals)

  // Adds to `clazz` the pieces that initialize a module (singleton object):
  //   - a public static synthetic field `MODULE$` whose type is a reference
  //     to the class of `classDef`
  //   - a public static void function `<class>~~<static>~S~` whose body
  //     assigns a fresh allocation of the class to `MODULE$` and then calls
  //     `<class>~~init^~L<class>^` with that field access as its argument
  // All generated nodes are positioned at `classDef.pos`.
  def addModuleInit(classDef: ClassDef, clazz: STClassDecl): Unit = {
    clazz.addField {
      val field = new STFieldDecl(classDef.pos)
      field.setName("MODULE$")
      field.setType(typeForSymbol(classDef.symbol, ref = true))
      field.setModifiers(NSTModifiers.Public)
      field.setModifiers(NSTModifiers.Static)
      field.setModifiers(NSTModifiers.Synthetic)
      field
    }
    // the static initializer function
    val staticInit = new STFunDecl(classDef.pos)
    staticInit.setModifiers(NSTModifiers.Public)
    staticInit.setModifiers(NSTModifiers.Static)
    staticInit.setReturnType(VoidType)
    val squig = squiggles(className(classDef.symbol))
    staticInit.setName(s"$squig~~<static>~S~")
    staticInit.setSimpleName("<static>")
    val body = new STBlock(classDef.pos)
    // MODULE$ = <allocation of the class>
    val assign = new STAssignmentStmt(classDef.pos)
    // note: this same access node is used both as the assignment target
    // and, below, as the argument of the init call
    val access =
      new STStaticFieldAccess(classDef.pos, "MODULE$", typeForSymbol(classDef.symbol, ref = false))
    assign.setLeft(access)
    assign.setRight {
      val rhs = new STAllocation(classDef.pos)
      rhs.setType(typeForSymbol(classDef.symbol, ref = false))
      rhs
    }
    body.add(assign)
    // call to the init function, passing MODULE$
    body.add {
      val call = new STFunCall(classDef.pos)
      call.setName(s"$squig~~init^~L$squig^")
      call.addArgument(access)
      call
    }
    staticInit.setBody(body)
    clazz.addFunction(staticInit)
  }

  // The human-readable name for the method `symbol`: the owner's decoded name
  // for constructors, otherwise the method's own decoded name. For lifted,
  // non-synthetic methods that are not anonymous functions, a trailing
  // `$<digits>` is removed.
  def toSimpleName(symbol: Symbol): String = {
    def isRenamedLocalDef =
      symbol.isLifted &&
        !isSyntheticMethod(symbol) &&
        !symbol.simpleName.startsWith(nme.ANON_FUN_NAME)
    if (symbol.isConstructor)
      symbol.owner.name.decoded
    else {
      val decoded = symbol.name.decoded
      if (isRenamedLocalDef)
        decoded.replaceFirst("\\$\\d+$", "") // undo LambdaLift.renameSym for local defs
      else
        decoded
    }
  }

  // Whether the class `symbol` should be marked synthetic: it carries the
  // synthetic flag and is not an anonymous function class.
  def isSyntheticClass(symbol: Symbol): Boolean =
    if (symbol.isSynthetic)
      // there's user code in these!
      !symbol.isAnonymousFunction
    else
      false

  // Whether the method `symbol` should be marked synthetic. True for:
  // synthetic-flagged methods (except `delayedEndpoint$...`), specialized
  // methods, accessors, param accessors, constructors of synthetic owners,
  // and methods whose name ends in `$extension`.
  // NOTE(review): `methodName` matches `$extension` followed by optional
  // digits, this check matches only the exact suffix -- confirm that the
  // difference is intended.
  def isSyntheticMethod(symbol: Symbol): Boolean = {
    val name = symbol.name
    // `delayedEndpoint$` methods contain actual user code
    def compilerGenerated =
      symbol.isSynthetic && !name.startsWith("delayedEndpoint$")
    def constructorOfSyntheticOwner =
      symbol.isConstructor && symbol.owner.isSynthetic
    compilerGenerated ||
    symbol.isSpecialized ||
    symbol.isAccessor ||
    symbol.isParamAccessor ||
    constructorOfSyntheticOwner ||
    name.endsWith("$extension")
  }

  // Builds the NST field declaration for `symbol`: name, reference type,
  // Synthetic (where applicable), Public or Private, and Static modifiers.
  def toFieldDecl(symbol: Symbol): STFieldDecl = {
    val field = new STFieldDecl(symbol.pos)
    field.setName(variableName(symbol, uniquify = false))
    field.setType(typeForType(symbol.tpe, ref = true))
    // we don't want to mark `private[this]` fields as synthetic, since they
    // have no accessors, there is nothing but the field.  so that's the reason
    // we look for a getter here
    def hasGetter = symbol.getterIn(symbol.owner) != NoSymbol
    val markSynthetic =
      !symbol.isJavaDefined && (symbol.isImplementationArtifact || hasGetter)
    if (markSynthetic)
      field.setModifiers(NSTModifiers.Synthetic)
    val visibility =
      if (symbol.isPublic) NSTModifiers.Public
      else NSTModifiers.Private
    field.setModifiers(visibility)
    if (symbol.isStaticMember)
      field.setModifiers(NSTModifiers.Static)
    field
  }

  // Builds the NST name of the method `symbol`, in the form
  // `<owner>~~<name>~<extra><argTypes>` where
  //   - <owner> is `squiggles(className(symbol.owner))`
  //   - <name> is "init^" for constructors, otherwise `unDollar` of the method name
  //   - <extra> is "S~" for static members, otherwise `L<owner>^`
  //   - <argTypes> is `typeString` (then `unDollar`) of each parameter type, concatenated
  // `Object.equals` short-circuits to a fixed name; in every other case the
  // method and its owner are reported via `seen`.
  def methodName(symbol: Symbol): String =
    // special case because 2.13 gives the expected name
    // but 2.11 and 2.12 have `scala~Any` as the third type.
    // the 2.13 one seems right to me, and idk if SCA has special handling
    // of `equals` but if it does, matching Java here should help it
    if (symbol eq definitions.Object_equals)
      "java~lang~Object~~equals~Ljava~lang~Object^Ljava~lang~Object^"
    else {
      seen(symbol)
      seen(symbol.owner)
      val clazz = squiggles(className(symbol.owner))
      val name =
        if (symbol.isConstructor) "init^"
        else unDollar(symbol.name.toString)
      // static members carry a marker; instance members carry the receiver type
      val extra =
        if (symbol.isStaticMember)
          "S~"
        else
          s"L$clazz^"
      val argTypes =
        // names ending in `$extension` plus optional digits
        if (symbol.name.toString.matches(""".*\$extension\d*"""))
          // for reasons that utterly baffle me, `paramTypes` was giving me different
          // results when running the actual plugin, than when running EndToEndTests
          // here in-repo. (I have never experienced that before in all my years
          // of working on this project!) so let's get the types a different way.
          // it's an open question what the right way is to translate
          // calls to `implicit class` based extension methods, and the way we are
          // currently doing it isn't necessarily right. but this is the
          // narrowest fix I could come up with to eliminate the behavior difference
          // between the two repos without altering existing test outputs
          // in the test repo. if it proves difficult to align the Scala 3 behavior
          // with this I could rethink, otherwise :shrug: I guess. as usual, if
          // it seems unlikely to affect vuln detection it isn't worth pouring too
          // much time into.
          global.asInstanceOf[nsc.Global]
            .exitingErasure(symbol.info.paramss.flatten.map(_.tpe))
        else
          symbol.info.paramTypes
      val argTypeString =
        argTypes
          .map(typeString)
          .map(unDollar)
          .mkString
      s"$clazz~~$name~$extra$argTypeString"
    }

  // state behind `uniquifyVariable`: the name already handed out to each
  // symbol, and the set of all names handed out so far
  private val uniqueNames = collection.mutable.Map.empty[Symbol, String]
  private val usedNames = collection.mutable.Set.empty[String]

  // Forgets every name handed out by `uniquifyVariable`.
  def resetLocalNames(): Unit = {
    usedNames.clear()
    uniqueNames.clear()
  }

  // Returns a name for `sym` that no other symbol has received since the last
  // `resetLocalNames()`. The first symbol with a given name keeps it as is;
  // later ones get `<name>~1`, `<name>~2`, ... (first free index). Asking
  // again for the same symbol returns the same name.
  def uniquifyVariable(sym: Symbol): String =
    uniqueNames.getOrElseUpdate(
      sym, {
        val base = sym.name.toString
        val fresh =
          if (usedNames.contains(base))
            Iterator
              .from(1)
              .map(n => s"$base~$n")
              .dropWhile(usedNames.contains)
              .next()
          else
            base
        usedNames += fresh
        fresh
      }
    )

  // The NST name for the variable, parameter or field `symbol`.
  // With `uniquify` set, the name comes from `uniquifyVariable`; otherwise
  // the symbol's own name is used. Synthetic and artifact symbols get a
  // "~t~" prefix.
  def variableName(symbol: Symbol, uniquify: Boolean = false): String = {
    val rawName =
      if (isOuterParam(symbol))
        "@outer" // otherwise we end up with "arg@outer"
      else if (uniquify)
        uniquifyVariable(symbol)
      else
        symbol.name.toString
    // for field names scalac is somehow coming up with a name
    // with a space at the end, I guess to distinguish it from the
    // same-named accessor? strip that space
    val cleaned = unDollar(rawName).replaceFirst(" $", "")
    val compilerGenerated = symbol.isSynthetic || symbol.isArtifact
    if (compilerGenerated) "~t~" + cleaned
    else cleaned
  }

  // copied from Scala 2.12 sources
  // Whether `symbol` is a constructor parameter named like the outer-instance
  // parameter (`nme.OUTER`, or that name prefixed with "arg").
  def isOuterParam(symbol: Symbol): Boolean = {
    def hasOuterName = {
      val name = symbol.name
      name == TermName("arg" + nme.OUTER) || name == nme.OUTER
    }
    symbol.isParameter && symbol.owner.isConstructor && hasOuterName
  }

  // The NST type for the class `symbol`. Unit and the eight primitive value
  // classes map to the corresponding NST primitive types (for which `ref` is
  // ignored); any other symbol maps to a class type, wrapped in a pointer
  // type when `ref` is true. Reports `symbol` via `seen`.
  def typeForSymbol(symbol: Symbol, ref: Boolean = true): STType = {
    seen(symbol)
    def info = new SourceInfo
    val primitive: Option[STType] = symbol match {
      case definitions.UnitClass    => Some(STType.makePrimitiveVoid(info))
      case definitions.IntClass     => Some(STType.makePrimitiveInt(info))
      case definitions.BooleanClass => Some(STType.makePrimitiveBoolean(info))
      case definitions.FloatClass   => Some(STType.makePrimitiveFloat(info))
      case definitions.DoubleClass  => Some(STType.makePrimitiveDouble(info))
      case definitions.CharClass    => Some(STType.makePrimitiveChar(info))
      case definitions.ByteClass    => Some(STType.makePrimitiveByte(info))
      case definitions.ShortClass   => Some(STType.makePrimitiveShort(info))
      case definitions.LongClass    => Some(STType.makePrimitiveLong(info))
      case _                        => None
    }
    primitive.getOrElse {
      val classType = new STType.STClassType(symbol.pos, className(symbol))
      if (ref) new STType.STPointerType(symbol.pos, classType)
      else classType
    }
  }

  // use this instead of typeForSymbol in more places, maybe eliminate
  // typeForSymbol entirely? be careful about "seen"
  //
  // The NST type for `tpe`. Arrays become an array type over the element
  // type (the element is always translated with ref = true), wrapped in a
  // pointer type when `ref` is true; everything else is delegated to
  // `typeForSymbol`. Reports the type's symbol via `seen`.
  def typeForType(tpe: Type, ref: Boolean = true): STType = {
    val sym = tpe.typeSymbol
    seen(sym)
    if (definitions.ArrayClass == sym) {
      val elementType = typeForType(tpe.typeArgs.head, ref = true)
      val arrayType = new STType.STArrayType(sym.pos, elementType)
      if (ref) new STType.STPointerType(sym.pos, arrayType)
      else arrayType
    }
    else
      typeForSymbol(sym, ref)
  }

  // The encoding of `tpe` as used inside NST method names: a one-letter code
  // followed by `^` for primitives, `?` plus the element encoding for arrays,
  // and `L<class>^` for everything else (`Any` is encoded as
  // java.lang.Object). Erased value types are encoded as their underlying
  // type. Reports the type's symbol via `seen`.
  // NOTE(review): Long is encoded as "L^", sharing the `L` prefix with class
  // types -- confirm this is the encoding the consumer expects.
  def typeString(tpe: Type): String = {
    val sym = tpe.typeSymbol
    seen(sym)
    tpe match {
      case ErasedValueType(_, underlying) =>
        typeString(underlying)
      case _ =>
        sym match {
          case definitions.IntClass     => "I^"
          case definitions.BooleanClass => "Z^"
          case definitions.CharClass    => "C^"
          case definitions.LongClass    => "L^"
          case definitions.FloatClass   => "F^"
          case definitions.DoubleClass  => "D^"
          case definitions.ShortClass   => "S^"
          case definitions.ByteClass    => "B^"
          case definitions.ArrayClass   => "?" + typeString(tpe.typeArgs.head)
          case definitions.AnyClass     => "Ljava~lang~Object^"
          case _                        => "L" + squiggles(className(sym)) + "^"
        }
    }
  }

  // The full name of the class `symbol`, with a trailing `$` appended for
  // module classes that are not defined in Java.
  def className(symbol: Symbol): String = {
    val moduleSuffix =
      if (symbol.isModuleClass && !symbol.isJavaDefined) "$"
      else ""
    symbol.fullName + moduleSuffix
  }

  /// other helpers

  // Whether `block` looks like a translated pattern match: the first label
  // definition among its statements (if there is one) has a synthetic case
  // symbol.
  def isPatternMatch(block: Block): Boolean =
    block match {
      case Block(statements, _) =>
        statements.collectFirst { case label: LabelDef => label } match {
          case Some(firstLabel) => treeInfo.hasSynthCaseSymbol(firstLabel)
          case None             => false
        }
    }

  // symbols of the value classes treated as numeric (Boolean and Unit are not)
  private lazy val numericTypes: Set[Symbol] =
    Set(
      definitions.IntClass,
      definitions.DoubleClass,
      definitions.FloatClass,
      definitions.ShortClass,
      definitions.ByteClass,
      definitions.LongClass,
      definitions.CharClass
    )

  // Whether the symbol of `tpe` is one of the numeric value classes.
  def isNumericType(tpe: Type): Boolean =
    numericTypes.contains(tpe.typeSymbol)

  // Whether `tpe` conforms to `AnyRef`.
  def isReferenceType(tpe: Type): Boolean = {
    val anyRef = definitions.AnyRefTpe
    tpe <:< anyRef
  }

  // Whether the symbol of `tpe` is `scala.Boolean`.
  def isBooleanType(tpe: Type): Boolean =
    definitions.BooleanClass == tpe.typeSymbol

  // Whether the symbol of `tpe` is `scala.Array`.
  def isArrayType(tpe: Type): Boolean =
    definitions.ArrayClass == tpe.typeSymbol

  // Whether `sym` is to be treated as abstract: it is abstract, or it is a
  // super accessor, or it carries the synthesize-impl-in-subclass flag.
  def isAbstract(sym: Symbol): Boolean = {
    def isSuperAccessor = sym.hasFlag(Flags.SUPERACCESSOR)
    // via VersionSpecificHelpers
    def implementedInSubclass = hasSynthesizeImplInSubclassFlag(sym)
    sym.isAbstract || isSuperAccessor || implementedInSubclass
  }

  // 2.12 has this, 2.11 doesn't, so inline the 2.11 definition
  // Whether `s` is a field: a term that is not a module and is either not a
  // method, or is an accessor owned by a trait.
  def isField(s: Symbol): Boolean = {
    def isTraitAccessor = s.owner.isTrait && s.isAccessor
    s.isTerm && !s.isModule && (!s.isMethod || isTraitAccessor)
  }

  // Whether the method definition `dd` should be translated. Outer accessors,
  // setters and Java main methods always are; any other method is, unless it
  // is a delambdafy target or a bridge.
  def shouldTranslate(dd: DefDef): Boolean = {
    val method = dd.symbol
    def alwaysTranslated =
      method.isOuterAccessor ||
        method.isSetter ||
        definitions.isJavaMainMethod(method)
    def excluded =
      isDelambdafyTarget(method) || method.hasFlag(Flags.BRIDGE)
    alwaysTranslated || !excluded
  }

  // If the class of `classDef` has a member `main` that qualifies as a Java
  // main method, returns a class declaration for the companion symbol holding
  // a public static synthetic `main(args)` that forwards to that method on
  // the `MODULE$` instance. Returns None when there is no such `main`.
  def synthesizeMain(classDef: ClassDef): Option[STClassDecl] = {
    // `main` may be overloaded; pick the alternative with the Java main signature
    val mainOption =
      classDef.symbol.tpe
        .member(nme.main)
        .alternatives
        .find(definitions.isJavaMainMethod)
    mainOption.map { main =>
      val clazz = toClassDecl(classDef.symbol.companionSymbol)
      clazz.addFunction {
        val fun = new STFunDecl
        fun.setName {
          val clazzName = squiggles(className(classDef.symbol.companionSymbol))
          s"$clazzName~~main~S~~?Ljava~lang~String^"
        }
        // single parameter `args`: pointer to array of pointer to String
        fun.addParameter(
          new STVarDecl(
            main.pos,
            "args",
            new STType.STPointerType(
              main.pos,
              new STType.STArrayType(
                main.pos,
                new STType.STPointerType(
                  main.pos,
                  new STType.STClassType(className(definitions.StringClass)))))
          ))
        fun.setSimpleName("main")
        fun.setReturnType(VoidType)
        fun.setModifiers(NSTModifiers.Public, NSTModifiers.Static, NSTModifiers.Synthetic)
        // body: a single virtual call to the real `main`, with `MODULE$` as
        // the receiver argument followed by `args`
        fun.setBody {
          val block = new STBlock(main.pos)
          block.add {
            val call = new STFunCall(main.pos)
            call.setVirtual()
            call.setName(methodName(main))
            seen(classDef.symbol)
            call.addArgument(
              new STStaticFieldAccess(
                main.pos,
                "MODULE$",
                new STType.STClassType(main.pos, classDef.symbol.fullName + "$")))
            call.addArgument(new STVarAccess(main.pos, "args"))
            call
          }
          block
        }
        fun
      }
      clazz
    }
  }

  // cribbed from Scala.js
  // Extractor for calls to one of Predef's array-wrapping methods with a
  // single argument; yields that argument.
  object WrapArray {
    import definitions.{getMemberMethod, PredefModule}

    // the Predef methods recognized as array wrappers
    lazy val isWrapArray: Set[Symbol] = {
      val wrapperNames = Seq(
        nme.wrapRefArray,
        nme.wrapByteArray,
        nme.wrapShortArray,
        nme.wrapCharArray,
        nme.wrapIntArray,
        nme.wrapLongArray,
        nme.wrapFloatArray,
        nme.wrapDoubleArray,
        nme.wrapBooleanArray,
        nme.wrapUnitArray,
        nme.genericWrapArray
      )
      wrapperNames.map(name => getMemberMethod(PredefModule, name)).toSet
    }

    def unapply(tree: Apply): Option[Tree] = tree match {
      case Apply(fun, wrapped :: Nil) if isWrapArray.contains(fun.symbol) =>
        Some(wrapped)
      case _ =>
        None
    }
  }

  /// operators

  // reference-comparison method names and the NST operator each maps to
  val referenceOps: Map[String, NSTOperators] =
    Map(
      ("eq", NSTOperators.Equal),
      ("neq", NSTOperators.NotEqual)
    )

  // encoded names of arithmetic, bitwise, comparison and shift methods,
  // and the NST operator each maps to
  val arithmeticOps: Map[String, NSTOperators] =
    Map(
      // arithmetic
      ("$plus", NSTOperators.Add),
      ("$minus", NSTOperators.Subtract),
      ("$times", NSTOperators.Multiply),
      ("$div", NSTOperators.Divide),
      ("$percent", NSTOperators.Mod),
      // bitwise
      ("$amp", NSTOperators.And),
      ("$bar", NSTOperators.Or),
      ("$up", NSTOperators.Xor),
      // comparison
      ("$eq$eq", NSTOperators.Equal),
      ("$bang$eq", NSTOperators.NotEqual),
      ("$less", NSTOperators.LessThan),
      ("$less$eq", NSTOperators.LessThanEqual),
      ("$greater", NSTOperators.GreaterThan),
      ("$greater$eq", NSTOperators.GreaterThanEqual),
      // shifts
      ("$less$less", NSTOperators.Lshift),
      ("$greater$greater", NSTOperators.Rshift),
      ("$greater$greater$greater", NSTOperators.Bwrshift)
    )

  // names of the numeric conversion methods (`toInt`, `toLong`, ...)
  def conversionOps: Set[String] =
    List("Double", "Long", "Short", "Byte", "Float", "Char", "Int")
      .map(target => "to" + target)
      .toSet

}
