/* Hibernate, Relational Persistence for Idiomatic Java
 *
 * SPDX-License-Identifier: Apache-2.0
 * Copyright: Red Hat Inc. and Hibernate Authors
 */
package org.hibernate.reactive.persister.entity.impl;

import java.util.Iterator;

import org.hibernate.HibernateException;
import org.hibernate.MappingException;
import org.hibernate.boot.model.relational.SqlStringGenerationContext;
import org.hibernate.dialect.CockroachDB192Dialect;
import org.hibernate.dialect.DB2Dialect;
import org.hibernate.dialect.Dialect;
import org.hibernate.dialect.PostgreSQL81Dialect;
import org.hibernate.dialect.SQLServerDialect;
import org.hibernate.id.IdentityGenerator;
import org.hibernate.id.PostInsertIdentityPersister;
import org.hibernate.id.insert.IdentifierGeneratingInsert;
import org.hibernate.id.insert.InsertGeneratedIdentifierDelegate;
import org.hibernate.sql.Insert;

/**
 * Fix the insert and select id queries generated by Hibernate ORM
 */
/**
 * An {@link IdentityGenerator} whose insert delegate rewrites the identity-generating
 * INSERT so that the generated identifier is returned by the INSERT statement itself.
 * <p>
 * The dialect-specific forms produced are:
 * <ul>
 *     <li>PostgreSQL / CockroachDB: {@code insert ... returning <id>}</li>
 *     <li>SQL Server: {@code insert into T (...) output inserted.<id> values (...)}</li>
 *     <li>DB2: {@code select <id> from new table (insert ...)}</li>
 * </ul>
 * Any other dialect falls back to the insert produced by ORM's {@code InsertSelectDelegate}.
 */
public class ReactiveIdentityGenerator extends IdentityGenerator {

	/**
	 * Always returns a {@link ReactiveInsertAndSelectDelegate}.
	 * The {@code isGetGeneratedKeysEnabled} flag is not consulted.
	 */
	@Override
	public InsertGeneratedIdentifierDelegate getInsertGeneratedIdentifierDelegate(
			PostInsertIdentityPersister persister, Dialect dialect, boolean isGetGeneratedKeysEnabled)
			throws HibernateException {
		return new ReactiveInsertAndSelectDelegate( persister, dialect );
	}

	/**
	 * Insert delegate that picks a dialect-specific {@link IdentifierGeneratingInsert}
	 * able to return the generated identifier from the INSERT itself.
	 */
	public static class ReactiveInsertAndSelectDelegate extends InsertSelectDelegate {

		// Retained here because this class uses them directly when building the insert.
		private final PostInsertIdentityPersister persister;
		private final Dialect dialect;

		public ReactiveInsertAndSelectDelegate(PostInsertIdentityPersister persister, Dialect dialect) {
			super( persister, dialect );
			this.persister = persister;
			this.dialect = dialect;
		}

		/**
		 * Creates the dialect-specific insert and registers the first root-table key column
		 * as its identity column.
		 */
		@Override
		public IdentifierGeneratingInsert prepareIdentifierGeneratingInsert(SqlStringGenerationContext context) {
			IdentifierGeneratingInsert insert = createInsert( context );
			// Registered on every path, including the ORM fallback, so the custom inserts
			// learn the identity column name they render into the statement.
			insert.addIdentityColumn( persister.getRootTableKeyColumnNames()[0] );
			return insert;
		}

		/**
		 * Chooses the insert implementation for the configured dialect.
		 * CockroachDB is checked explicitly alongside the PostgreSQL dialect family.
		 */
		private IdentifierGeneratingInsert createInsert(SqlStringGenerationContext context) {
			if ( dialect instanceof PostgreSQL81Dialect || dialect instanceof CockroachDB192Dialect ) {
				return new PostgresIdentifierGeneratingInsert( dialect );
			}
			if ( dialect instanceof SQLServerDialect ) {
				return new SqlServerIdentifierGeneratingInsert( dialect );
			}
			if ( dialect instanceof DB2Dialect ) {
				return new Db2IdentifierGeneratingInsert( dialect );
			}
			return super.prepareIdentifierGeneratingInsert( context );
		}
	}

	/**
	 * DB2 insert that wraps the plain INSERT in {@code select <id> from new table (...)}
	 * so the generated identifier is returned as a result set.
	 */
	public static class Db2IdentifierGeneratingInsert extends IdentifierGeneratingInsert {

		// Set by addIdentityColumn(); expected to be called before toStatementString().
		private String identityColumnName;

		public Db2IdentifierGeneratingInsert(Dialect dialect) {
			super( dialect );
		}

		@Override
		public Insert addIdentityColumn(String columnName) {
			this.identityColumnName = columnName;
			return super.addIdentityColumn( columnName );
		}

		/**
		 * @see Insert#toStatementString()
		 */
		@Override
		public String toStatementString() {
			return "select " + identityColumnName + " from new table (" + super.toStatementString() + ")";
		}
	}

	/**
	 * PostgreSQL/CockroachDB insert that appends {@code returning <id>} to the plain INSERT.
	 */
	public static class PostgresIdentifierGeneratingInsert extends IdentifierGeneratingInsert {

		// Set by addIdentityColumn(); expected to be called before toStatementString().
		private String identityColumnName;

		public PostgresIdentifierGeneratingInsert(Dialect dialect) {
			super( dialect );
		}

		@Override
		public Insert addIdentityColumn(String columnName) {
			this.identityColumnName = columnName;
			return super.addIdentityColumn( columnName );
		}

		/**
		 * @see Insert#toStatementString()
		 */
		@Override
		public String toStatementString() {
			return super.toStatementString() + " returning " + identityColumnName;
		}
	}

	/**
	 * SQL Server insert that places an {@code output inserted.<id>} clause between the
	 * column list (or table name) and the values / no-columns insert string.
	 * <p>
	 * The rendering otherwise mirrors {@link Insert#toStatementString()}. This is needed
	 * because the OUTPUT clause has to appear in the middle of the statement rather than
	 * being appended at the end.
	 */
	public static class SqlServerIdentifierGeneratingInsert extends IdentifierGeneratingInsert {

		// Set by addIdentityColumn(); expected to be called before toStatementString().
		private String identityColumnName;

		public SqlServerIdentifierGeneratingInsert(Dialect dialect) {
			super( dialect );
		}

		@Override
		public Insert addIdentityColumn(String columnName) {
			this.identityColumnName = columnName;
			return super.addIdentityColumn( columnName );
		}

		/**
		 * Renders the INSERT with an OUTPUT clause returning the identity column.
		 *
		 * @throws MappingException if there are no columns and the dialect does not
		 * support a no-columns insert
		 * @see Insert#toStatementString()
		 */
		@Override
		public String toStatementString() {
			StringBuilder buf = new StringBuilder( columns.size() * 15 + tableName.length() + 10 );
			if ( comment != null ) {
				buf.append( "/* " ).append( Dialect.escapeComment( comment ) ).append( " */ " );
			}
			buf.append( "insert into " ).append( tableName );
			if ( columns.isEmpty() ) {
				if ( !getDialect().supportsNoColumnsInsert() ) {
					throw new MappingException( String.format(
							"The INSERT statement for table [%s] contains no column, and this is not supported by [%s]",
							tableName,
							getDialect()
					) );
				}
				// The OUTPUT clause is what ORM's Insert does not emit
				appendOutputClause( buf );
				buf.append( ' ' ).append( getDialect().getNoColumnsInsertString() );
			}
			else {
				buf.append( " (" );
				appendCommaSeparated( buf, columns.keySet() );
				buf.append( ')' );
				// The OUTPUT clause is what ORM's Insert does not emit
				appendOutputClause( buf );
				buf.append( " values (" );
				appendCommaSeparated( buf, columns.values() );
				buf.append( ')' );
			}
			return buf.toString();
		}

		/** Appends {@code " output inserted.<identity column>"} to {@code buf}. */
		private void appendOutputClause(StringBuilder buf) {
			buf.append( " output inserted." ).append( identityColumnName );
		}

		/** Appends {@code items} to {@code buf}, separated by {@code ", "}. */
		private static void appendCommaSeparated(StringBuilder buf, Iterable<String> items) {
			Iterator<String> iter = items.iterator();
			while ( iter.hasNext() ) {
				buf.append( iter.next() );
				if ( iter.hasNext() ) {
					buf.append( ", " );
				}
			}
		}
	}

}
