/*
 * Copyright © 2016-2024 Lightbend, Inc. All rights reserved.
 * No information contained herein may be reproduced or transmitted in any form
 * or by any means without the express written permission of Lightbend, Inc.
 */

package com.lightbend.tools.fortify.plugin

import scala.language.implicitConversions

import com.fortify.frontend.nst
import nst.*
import nodes.*

import dotty.tools.dotc
import dotc.ast.tpd
import dotc.core.Contexts.{Context, atPhase}
import dotc.core.Phases.*
import dotc.core.Symbols.toDenot
import dotc.core.Symbols.NoSymbol
import dotc.core.Flags.*
import dotc.core.StdNames.*
import dotc.core.Types.Type
import dotc.core.Names.Name
import dotc.util.SourcePosition

import VersionSpecificHelpers.*

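// This is where we put pieces of Translator that are (mostly) stateless
// and context-independent, apart from the compiler `Context` the trait is
// constructed with. The one piece of mutable state is the local-variable
// naming tables (`uniqueNames`, `usedNames`), cleared by `resetLocalNames()`.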

trait TranslatorHelpers(using ctx: Context)
    extends TranslatorBase
    with Positions:

  type Symbol = dotc.core.Symbols.Symbol

  // accessors
  def defn = ctx.definitions

  /** Reports that `symbol` has been referenced by the translation.
    * Abstract: what is recorded is up to the implementing class.
    */
  def seen(symbol: Symbol): Unit

  /** Builds the NST class declaration for class `symbol`.
    *
    * The simple name drops the trailing `$` of module classes; the full name
    * comes from [[className]]. Every parent except `Any` and `Matchable` is
    * recorded as an `extends`. The namespace is the owner's full name unless
    * the owner is the empty package. Interface/abstract/enum modifiers are
    * derived from the symbol's flags.
    */
  def toClassDecl(symbol: Symbol): STClassDecl =

    seen(symbol)

    val result = new STClassDecl(symbol.sourcePos)
    result.setSimpleName:
      val s = symbol.name.toString
      if symbol.is(ModuleClass)
      then s.init  // drop $
      else s
    result.setName(className(symbol))
    result.addModifiers(NSTModifiers.Public)

    for
      parent <- symbol.info.parents.map(_.classSymbol)
        if parent != defn.AnyClass && parent != defn.MatchableClass
    do
      result.addExtends(typeForSymbol(parent, ref = false))

    if symbol.owner != defn.EmptyPackageClass then
      result.setNamespace(symbol.owner.fullName.toString)
    if symbol.is(PureInterface) || symbol.isAllOf(JavaInterface) then
      result.addModifiers(NSTModifiers.Interface)
    else if symbol.is(Abstract) && !hasJavaEnumFlag(symbol) then
      result.addModifiers(NSTModifiers.Abstract)
    if hasJavaEnumFlag(symbol) then
      result.addModifiers(NSTModifiers.Enum)
      result.addModifiers(NSTModifiers.Final)
      result.addModifiers(NSTModifiers.Static)
    end if
    result

  /** Fully-qualified mangled name of `symbol`. For Java-defined module
    * classes the last character (the `$`) is dropped.
    */
  def className(symbol: Symbol): String =
    val s = symbol.fullName.mangledString
    if symbol.is(ModuleClass) && symbol.is(JavaDefined)
    then s.init  // drop $
    else s

  /** True for Java static methods other than constructors and platform main
    * methods. These get an `S~` marker in their NST name, a `Static` modifier
    * and no explicit `this~` receiver. Every other method is treated as an
    * instance method.
    */
  private def isJavaStaticMethod(symbol: Symbol): Boolean =
    symbol.is(JavaDefined) &&
      symbol.isStatic &&
      !symbol.isConstructor &&
      !ctx.platform.isMainMethod(symbol)

  /** Builds the NST function declaration for method `symbol`.
    *
    * Non-static methods get a leading `this~` parameter typed as the owner.
    * The remaining parameters come from the four parallel lists (iteration
    * stops at the shortest). A parameter's name comes from its symbol when
    * that symbol exists, otherwise from the corresponding entry of
    * `paramNames`. Unit-returning methods and constructors get a void return
    * type.
    */
  def toFunDecl(
      symbol: Symbol,
      paramSymbols: List[Symbol],
      paramNames: List[Name],
      paramPositions: List[SourcePosition],
      paramInfos: List[Type]
  ): STFunDecl =
    seen(symbol)
    val result = new STFunDecl(symbol.sourcePos)
    result.setName(methodName(symbol))
    result.setSimpleName(toSimpleName(symbol))
    val javaStatic = isJavaStaticMethod(symbol)
    if !javaStatic then
      val xthis = new STVarDecl(symbol.denot.owner.sourcePos, "this~",
        typeForSymbol(symbol.denot.owner))
      result.addParameter(xthis)
    // distinct loop names so the method's own `symbol` is never shadowed
    for (paramSym, paramName, position, info) <-
        paramSymbols.lazyZip(paramNames).lazyZip(paramPositions).lazyZip(paramInfos)
    do
      val vn =
        if paramSym.exists then variableName(paramSym, uniquify = false)
        else variableName(paramName.toString)
      result.addParameter(new STVarDecl(position, vn, typeForType(info)))
    val isVoid =
      symbol.info.resultType.typeSymbol == defn.UnitClass ||
        // in NST constructors have void return type
        symbol.isConstructor
    result.setReturnType(
      if isVoid
      then VoidType
      else typeForType(symbol.info.resultType)
    )
    result.setModifiers(NSTModifiers.Public)
    if symbol.is(Deferred) then
      result.setModifiers(NSTModifiers.Abstract)
    if javaStatic then
      result.setModifiers(NSTModifiers.Static)
    if isSyntheticMethod(symbol) then
      result.setModifiers(NSTModifiers.Synthetic)
    for over <- overrides(symbol) do
      result.addOverride(methodName(over))
    result

  /** Symbols (despite the name, not names) of all of `dd`'s parameters,
    * flattened across its parameter clauses. NOTE(review): Scala 3's `paramss`
    * presumably also includes type-parameter clauses -- confirm.
    */
  def paramNames(dd: tpd.DefDef): List[Symbol] =
    dd.paramss.flatten.map(_.symbol)

  /** NST name of method `symbol`, in the form
    * `<class>~~<name>~<receiver><argTypes>`.
    *
    * `<class>` is the squiggled owner class name and `<name>` is `init^` for
    * constructors. `<receiver>` is `S~` for Java statics, `L<class>^`
    * otherwise. `<argTypes>` concatenates [[typeString]] of the first
    * parameter list. Returns a placeholder for non-existent symbols.
    */
  def methodName(symbol: Symbol): String =
    if !symbol.exists then
      "<DOES NOT EXIST>" // TODO
    else
      seen(symbol)
      seen(symbol.owner)
      val clazz = squiggles(className(symbol.owner))
      val name =
        if symbol.isConstructor
        then "init^"
        else unDollar(symbol.name.mangledString)
      val receiver =
        if isJavaStaticMethod(symbol)
        then "S~"
        else s"L$clazz^"
      val argTypes =
        symbol.info.firstParamTypes
          .map(typeString)
          .map(unDollar)
          .mkString
      s"$clazz~~$name~$receiver$argTypes"

  /** Unqualified display name: the owner's name for constructors, otherwise
    * the decoded method name.
    */
  def toSimpleName(symbol: Symbol): String =
    if symbol.isConstructor
    then symbol.denot.owner.showName
    else symbol.denot.name.decode.toString

  // We have to query an earlier phase or the information is missing, but
  // it's not clear how far back we need to go. (The Scala 2 version goes
  // back to uncurry.) We try erasure here and will adjust later if needed.
  def overrides(symbol: Symbol): Iterator[Symbol] =
    atPhase(erasurePhase)(symbol.allOverriddenSymbols)

  /** Adds module-initialization members to `clazz`:
    *  - a public static synthetic `MODULE$` field of the module's type, and
    *  - a public static `<static>` initializer that allocates an instance,
    *    stores it in `MODULE$` and calls its constructor (`init^`) with
    *    `MODULE$` as the receiver.
    */
  def addModuleInit(classDef: tpd.TypeDef, clazz: STClassDecl): Unit =
    clazz.addField:
      val field = new STFieldDecl(classDef.sourcePos)
      field.setName("MODULE$")
      field.setType(typeForSymbol(classDef.symbol, ref = true))
      field.setModifiers(NSTModifiers.Public)
      field.setModifiers(NSTModifiers.Static)
      field.setModifiers(NSTModifiers.Synthetic)
      field
    val staticInit = new STFunDecl(classDef.sourcePos)
    staticInit.setModifiers(NSTModifiers.Public)
    staticInit.setModifiers(NSTModifiers.Static)
    staticInit.setReturnType(VoidType)
    val squig = squiggles(className(classDef.symbol))
    staticInit.setName(s"$squig~~<static>~S~")
    staticInit.setSimpleName("<static>")
    val body = new STBlock(classDef.sourcePos)
    val assign = new STAssignmentStmt(classDef.sourcePos)
    // NOTE(review): this same node is used both as the assignment target and
    // as the constructor call's receiver argument below.
    val access =
      new STStaticFieldAccess(classDef.sourcePos, "MODULE$", typeForSymbol(classDef.symbol, ref = false))
    assign.setLeft(access)
    assign.setRight:
      val rhs = new STAllocation(classDef.sourcePos)
      rhs.setType(typeForSymbol(classDef.symbol, ref = false))
      rhs
    body.add(assign)
    body.add:
      val call = new STFunCall(classDef.sourcePos)
      call.setName(s"$squig~~init^~L$squig^")
      call.addArgument(access)
      call
    staticInit.setBody(body)
    clazz.addFunction(staticInit)

  /** A method counts as synthetic if it is compiler-generated (but not a
    * lifted method), any kind of accessor, or a getter as seen at erasure.
    */
  def isSyntheticMethod(symbol: Symbol): Boolean =
    (symbol.is(Synthetic) && !symbol.is(Lifted)) ||
      symbol.is(Accessor) ||
      symbol.is(ParamAccessor) ||
      atPhase(erasurePhase):
        symbol.isGetter

  /** Builds the NST field declaration for field `symbol` (pointer-typed,
    * public or private, static for Java statics).
    */
  def toFieldDecl(symbol: Symbol): STFieldDecl =
    val result = new STFieldDecl(symbol.sourcePos)
    result.setName(variableName(symbol, uniquify = false))
    result.setType(typeForType(symbol.info, ref = true))
    // We don't want to mark `private[this]` (`Local`) fields as synthetic:
    // they have no accessors, so there is nothing but the field. Hence only
    // non-Local fields (checked as of typer) or artifacts are synthetic.
    def isNonLocal =
      atPhase(typerPhase):
        !symbol.is(Local)
    if !symbol.is(JavaDefined) && (symbol.is(Artifact) || isNonLocal) then
      result.setModifiers(NSTModifiers.Synthetic)
    result.setModifiers(
      if symbol.isPublic
      then NSTModifiers.Public
      else NSTModifiers.Private)
    if symbol.is(JavaStatic) then
      result.setModifiers(NSTModifiers.Static)
    result

  // per-scope naming state; see resetLocalNames() / uniquifyVariable()
  private val uniqueNames = collection.mutable.Map[Symbol, String]()
  private val usedNames = collection.mutable.Set[String]()

  /** Forgets all names handed out by [[uniquifyVariable]]. */
  def resetLocalNames(): Unit =
    uniqueNames.clear()
    usedNames.clear()

  /** Stable unique name for local `sym`. The first request for a symbol
    * claims its plain name, or `name~N` with the smallest free N >= 1.
    * Later requests return the same name.
    */
  def uniquifyVariable(sym: Symbol): String =
    uniqueNames.getOrElseUpdate(sym, claimFreshName(sym.name.toString))

  /** Reserves and returns `base` if unused, else the first unused `base~N`. */
  private def claimFreshName(base: String): String =
    val fresh =
      if !usedNames.contains(base) then base
      else
        Iterator.from(1)
          .map(i => s"$base~$i")
          .dropWhile(usedNames.contains)
          .next()
    usedNames += fresh
    fresh

  /** Applies `unDollar` and strips one trailing space. */
  def variableName(s: String): String =
    unDollar(s).replaceFirst(" $", "")

  /** NST variable name for `symbol`, optionally uniquified. Synthetic and
    * artifact symbols get a `~t~` prefix.
    */
  def variableName(symbol: Symbol, uniquify: Boolean = false): String =
    def baseName =
      if uniquify
      then uniquifyVariable(symbol)
      else symbol.name.toString
    val result = variableName(baseName)
    if symbol.is(Synthetic) || symbol.is(Artifact)
    then s"~t~$result"
    else result

  /** NST type for class `symbol`. Primitive value classes map to NST
    * primitives (Unit to void). Anything else becomes a class type named
    * after the dealiased symbol, wrapped in a pointer type when `ref`.
    */
  def typeForSymbol(symbol: Symbol, ref: Boolean = true): STType =
    seen(symbol)
    if symbol == defn.UnitClass then
      STType.makePrimitiveVoid(new SourceInfo)
    else if symbol == defn.IntClass then
      STType.makePrimitiveInt(new SourceInfo)
    else if symbol == defn.BooleanClass then
      STType.makePrimitiveBoolean(new SourceInfo)
    else if symbol == defn.FloatClass then
      STType.makePrimitiveFloat(new SourceInfo)
    else if symbol == defn.DoubleClass then
      STType.makePrimitiveDouble(new SourceInfo)
    else if symbol == defn.CharClass then
      STType.makePrimitiveChar(new SourceInfo)
    else if symbol == defn.ByteClass then
      STType.makePrimitiveByte(new SourceInfo)
    else if symbol == defn.ShortClass then
      STType.makePrimitiveShort(new SourceInfo)
    else if symbol == defn.LongClass then
      STType.makePrimitiveLong(new SourceInfo)
    else
      val classType =
        new STType.STClassType(symbol.sourcePos, className(symbol.info.widenDealias.typeSymbol))
      if ref then
        new STType.STPointerType(symbol.sourcePos, classType)
      else
        classType

  // use this instead of typeForSymbol in more places, maybe eliminate
  // typeForSymbol entirely? be careful about "seen"
  /** NST type for `tpe`. Arrays (after widening and dealiasing) become NST
    * array types, pointer-wrapped when `ref`. Everything else defers to
    * [[typeForSymbol]].
    */
  def typeForType(tpe: Type, ref: Boolean = true): STType =
    seen(tpe.typeSymbol)
    tpe.widenDealias match
      case defn.ArrayOf(elemType) =>
        val arr =
          new STType.STArrayType(tpe.typeSymbol.sourcePos, typeForType(elemType.widenDealias))
        if ref
        then new STType.STPointerType(tpe.typeSymbol.sourcePos, arr)
        else arr
      case t =>
        typeForSymbol(t.typeSymbol, ref)

  /** JVM-descriptor-like encoding of `tpe` used in NST method names, e.g.
    * `I^` for Int, `?I^` for Array[Int], `Lpkg~Cls^` for classes.
    */
  def typeString(tpe: Type): String =
    val sym = atPhase(typerPhase)(tpe.widenDealias).typeSymbol
    seen(sym)
    if      sym == defn.AnyClass     then "Ljava~lang~Object^"
    else if sym == defn.IntClass     then "I^"
    else if sym == defn.BooleanClass then "Z^"
    else if sym == defn.CharClass    then "C^"
    else if sym == defn.LongClass    then "L^"
    else if sym == defn.FloatClass   then "F^"
    else if sym == defn.DoubleClass  then "D^"
    else if sym == defn.ShortClass   then "S^"
    else if sym == defn.ByteClass    then "B^"
    // NOTE(review): unlike typeForType, this matches on `tpe` itself rather
    // than its widened form, so a singleton/by-name array type would fall
    // through to the class case -- confirm whether that is intended.
    else tpe match
      case defn.ArrayOf(elemType) =>
        s"?${typeString(elemType)}"
      case _ =>
        s"L${squiggles(className(sym))}^"

  /// other helpers

  private lazy val numericTypes =
    Set[Symbol](
      defn.IntClass, defn.DoubleClass, defn.FloatClass, defn.ShortClass,
      defn.ByteClass, defn.LongClass, defn.CharClass
    )

  /** True if `tpe`'s class is one of the numeric primitives (including Char). */
  def isNumericType(tpe: Type): Boolean =
    numericTypes(tpe.typeSymbol)

  def isReferenceType(tpe: Type): Boolean =
    tpe <:< defn.AnyRefType

  def isBooleanType(tpe: Type): Boolean =
    tpe.typeSymbol == defn.BooleanClass

  def isArrayType(tpe: Type): Boolean =
    tpe.typeSymbol == defn.ArrayClass

  // It's difficult to decide whether or not to translate
  // bridges, forwarders, mixed-in trait methods, interface default
  // methods... a tangle overlapping things. see
  // lightbend/scala-fortify#297 for some history and discussion.
  // perhaps we've gotten this correct now, or perhaps it will
  // need further work... time will tell.
  def shouldTranslate(dd: tpd.DefDef): Boolean =
    !dd.symbol.is(Bridge) &&
      !dd.symbol.isAnonymousFunction &&
      // Scala 3 generates this, 2 doesn't. it gets in the way of aligning
      // the results from the two compilers and I don't think omitting it
      // could affect vuln detection
      dd.name != nme.writeReplace

  /** If `classDef`'s type has a `main` member that the platform recognizes as
    * a main method, returns a synthetic class declaration for its companion
    * module. That class holds a public static `main(String[])` which calls
    * the module's `main` on `MODULE$`, passing `args`. Otherwise `None`.
    */
  def synthesizeMain(classDef: tpd.TypeDef): Option[STClassDecl] =
    val mainOption =
      classDef.symbol.info
        .member(nme.main)
        .altsWith(ctx.platform.isMainMethod)
        .headOption
        .map(_.symbol)
    mainOption.map: main =>
      val clazz = toClassDecl(classDef.symbol.companionModule)
      clazz.addFunction:
        val fun = new STFunDecl
        fun.setName:
          val clazzName = squiggles(className(classDef.symbol.companionModule))
          s"$clazzName~~main~S~~?Ljava~lang~String^"
        fun.addParameter(
          new STVarDecl(
            main.sourcePos,
            "args",
            new STType.STPointerType(
              main.sourcePos,
              new STType.STArrayType(
                main.sourcePos,
                new STType.STPointerType(
                  main.sourcePos,
                  new STType.STClassType(className(defn.StringClass)))))
          ))
        fun.setSimpleName("main")
        fun.setReturnType(VoidType)
        fun.setModifiers(NSTModifiers.Public, NSTModifiers.Static, NSTModifiers.Synthetic)
        fun.setBody:
          val block = new STBlock(main.sourcePos)
          block.add:
            val call = new STFunCall(main.sourcePos)
            call.setVirtual()
            call.setName(methodName(main))
            seen(classDef.symbol)
            call.addArgument(
              new STStaticFieldAccess(
                main.sourcePos,
                "MODULE$",
                new STType.STClassType(main.sourcePos, className(classDef.symbol))))
            call.addArgument(new STVarAccess(main.sourcePos, "args"))
            call
          block
        fun
      clazz

  /// operators

  val referenceOps: Map[Name, NSTOperators] =
    Map(
      nme.eq -> NSTOperators.Equal,
      nme.ne -> NSTOperators.NotEqual,
    )

  val arithmeticOps: Map[Name, NSTOperators] =
    Map(
      nme.Plus -> NSTOperators.Add,
      nme.Minus -> NSTOperators.Subtract,
      nme.Times -> NSTOperators.Multiply,
      nme.Div -> NSTOperators.Divide,
      nme.Mod -> NSTOperators.Mod,
      nme.AND -> NSTOperators.And,
      nme.OR -> NSTOperators.Or,
      nme.Xor -> NSTOperators.Xor,
      nme.Equals -> NSTOperators.Equal,
      nme.NotEquals -> NSTOperators.NotEqual,
      nme.Lt -> NSTOperators.LessThan,
      nme.Le -> NSTOperators.LessThanEqual,
      nme.Gt -> NSTOperators.GreaterThan,
      nme.Ge -> NSTOperators.GreaterThanEqual,
      nme.shiftSignedLeft -> NSTOperators.Lshift,
      nme.shiftSignedRight -> NSTOperators.Rshift,
      nme.shiftLogicalRight -> NSTOperators.Bwrshift
    )

  // built once instead of on every call
  private val conversionOpNames: Set[String] =
    Set("toDouble", "toLong", "toShort", "toByte", "toFloat", "toChar", "toInt")

  /** Names of the numeric conversion methods (`toInt`, `toDouble`, ...). */
  def conversionOps: Set[String] =
    conversionOpNames