#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from optparse import OptionParser
from StringIO import StringIO

import cmd
import sys
import readline
import os
import re

try:
    import cql
except ImportError:
    sys.path.append(os.path.abspath(os.path.dirname(__file__)))
    import cql
from cql.results import ResultSet

# Readline history file in the user's home directory; loaded when a Shell is
# created and written back when the shell exits.
HISTORY = os.path.join(os.path.expanduser('~'), '.cqlsh')
# Type names offered by tab completion for CREATE COLUMNFAMILY.
CQLTYPES = ("bytes", "ascii", "utf8", "timeuuid", "uuid", "long", "int")

# ANSI escape templates: each wraps a single %s in a bold colour sequence
# followed by the reset sequence.  Used with the "%" operator by
# Shell.printout when colour output is enabled.
RED = "\033[1;31m%s\033[0m"
GREEN = "\033[1;32m%s\033[0m"
BLUE = "\033[1;34m%s\033[0m"
YELLOW = "\033[1;33m%s\033[0m"
CYAN = "\033[1;36m%s\033[0m"
MAGENTA = "\033[1;35m%s\033[0m"

def startswith(words, text):
    """Return, in order, the entries of ``words`` that begin with ``text``.

    The result is always a new list, whatever sequence type ``words`` is.
    """
    matches = []
    for candidate in words:
        if candidate.startswith(text):
            matches.append(candidate)
    return matches

class Shell(cmd.Cmd):
    """Line-oriented interactive shell that sends CQL statements to a server.

    Input lines that are not built-in commands reach ``default``, which
    strips one-line comments and passes the rest to ``get_statement``.
    ``get_statement`` buffers lines until a statement is complete (a line
    ending in ``;`` or, inside a batch, ``APPLY BATCH``); the complete text
    is then executed on the connection's cursor and the result printed.

    State kept between input lines:
      statement   -- buffer holding the lines of the statement being typed
      in_comment  -- True while inside an unterminated ``/* ... */`` comment
      in_batch    -- True between ``BEGIN BATCH`` and ``APPLY BATCH``
      color       -- whether ``printout`` applies colour templates
    """
    default_prompt  = "cqlsh> "
    continue_prompt = "   ... "

    def __init__(self, hostname, port, color=False, username=None,
            password=None):
        """Connect to ``hostname:port`` and prepare the interactive state.

        ``username`` and ``password`` are forwarded to ``cql.connect``;
        ``color`` enables coloured output in ``printout``.
        """
        cmd.Cmd.__init__(self)
        self.conn = cql.connect(hostname, port, user=username, password=password)
        self.cursor = self.conn.cursor()

        if os.path.exists(HISTORY):
            readline.read_history_file(HISTORY)

        # No prompt at all when input is piped in rather than typed.
        if sys.stdin.isatty():
            self.prompt = Shell.default_prompt
        else:
            self.prompt = ""

        self.statement = StringIO()
        self.color = color
        self.in_comment = False
        self.in_batch = False

    def reset_statement(self):
        """Discard any partially entered statement and restore the prompt.

        The multi-line flags are cleared as well, so that abandoning input
        (e.g. on an interrupt) does not leave the shell stuck inside a
        batch or a block comment.
        """
        self.set_prompt(Shell.default_prompt)
        # Rewind before truncating so the next write starts at offset 0 on
        # every StringIO implementation.
        self.statement.seek(0)
        self.statement.truncate(0)
        self.in_comment = False
        self.in_batch = False

    def _take_statement(self):
        """Return the buffered statement text and reset the buffer."""
        try:
            return self.statement.getvalue()
        finally:
            self.reset_statement()

    def get_statement(self, line):
        """Feed one input line into the statement buffer.

        Returns the complete statement text once ``line`` finishes a
        statement, otherwise ``None`` (more input is needed, or the line was
        entirely part of a block comment).
        """
        if self.in_comment:
            if "*/" not in line:
                return None
            self.in_comment = False
            fragment = line[line.index("*/") + 2:]
            if not fragment.strip():
                # Only go back to the primary prompt when nothing is pending.
                if not self.in_batch and not self.statement.getvalue():
                    self.set_prompt(Shell.default_prompt)
                return None
            # Text following the end of the comment is processed normally.
            line = fragment

        if "/*" in line:
            self.in_comment = True
            self.set_prompt(Shell.continue_prompt)
            # Keep whatever precedes the comment; the offset is taken on the
            # same (unstripped) string that is sliced.
            prefix = line[:line.index("/*")]
            if prefix.strip():
                self.statement.write(prefix)
            return None

        upper = line.upper()
        if upper.split()[:2] == ["BEGIN", "BATCH"]:
            self.set_prompt(Shell.continue_prompt)
            self.in_batch = True
            self.statement.write("%s\n" % line)
            return None
        # A trailing ";" after APPLY BATCH is tolerated so the batch is not
        # left open waiting for a terminator that was already typed.
        if self.in_batch and \
                upper.rstrip().rstrip(";").split() == ["APPLY", "BATCH"]:
            self.in_batch = False
            self.statement.write("%s\n" % line)
            return self._take_statement()

        self.statement.write("%s\n" % line)

        # A statement ends at ";" (ignoring trailing whitespace, which is
        # what remains after a trailing "--" comment has been scrubbed).
        # Inside a batch every line is buffered until APPLY BATCH.
        if self.in_batch or not line.rstrip().endswith(";"):
            self.set_prompt(Shell.continue_prompt)
            return None

        return self._take_statement()

    @staticmethod
    def _scrub_oneline_comments(s):
        """Remove ``/* ... */`` and ``-- ...`` comments contained in one line.

        The block-comment match is non-greedy so that text between two
        separate comments on the same line is preserved.
        NOTE(review): quoting is not considered, so these sequences are
        removed even when they occur inside a string literal.
        """
        res = re.sub(r'/\*.*?\*/', '', s)
        res = re.sub(r'--.*$', '', res)
        return res

    def default(self, arg):
        """Handle any line that is not a built-in command as CQL input.

        Executes the statement once ``get_statement`` reports it complete
        and prints the outcome: for a ``ResultSet`` one line per row (first
        value, then ``name,value`` pairs separated by `` | ``); otherwise
        the first element of the cursor's result, if any.
        """
        text = self._scrub_oneline_comments(arg)
        if not text.strip():
            return
        statement = self.get_statement(text)
        if not statement:
            return

        self.cursor.execute(statement)

        if isinstance(self.cursor.result, ResultSet):
            for _ in range(self.cursor.rowcount):
                row = self.cursor.fetchone()
                self.printout(repr(row[0]), BLUE, False)
                for (i, value) in enumerate(row[1:]):
                    # description is indexed like the row; skip column 0.
                    name = self.cursor.description[i+1][0]
                    self.printout(" | ", newline=False)
                    self.printout(repr(name), MAGENTA, False)
                    self.printout(",", newline=False)
                    self.printout(repr(value), YELLOW, False)
                self.printout("")
        elif self.cursor.result:
            self.printout("%s" % (self.cursor.result[0],))

    def emptyline(self):
        """Do nothing on an empty line (cmd.Cmd would repeat the last command)."""
        pass

    def complete_select(self, text, line, begidx, endidx):
        """Complete keywords used in a SELECT statement."""
        keywords = ('FIRST', 'REVERSED', 'FROM', 'WHERE', 'KEY')
        return startswith(keywords, text.upper())
    complete_SELECT = complete_select

    def complete_update(self, text, line, begidx, endidx):
        """Complete keywords used in an UPDATE statement."""
        keywords = ('WHERE', 'KEY', 'SET')
        return startswith(keywords, text.upper())
    complete_UPDATE = complete_update

    def complete_create(self, text, line, begidx, endidx):
        """Complete CREATE: the object kind first, then its options.

        Always returns a list; an unrecognised object kind yields ``[]``.
        """
        words = line.split()
        if len(words) < 3:
            return startswith(['COLUMNFAMILY', 'KEYSPACE'], text.upper())

        common = ['WITH', 'AND']
        target = words[1].upper()

        if target == 'COLUMNFAMILY':
            types = startswith(CQLTYPES, text)
            keywords = startswith(('KEY', 'PRIMARY'), text.upper())
            props = startswith(("comparator",
                                "comment",
                                "row_cache_size",
                                "key_cache_size",
                                "read_repair_chance",
                                "gc_grace_seconds",
                                "default_validation",
                                "min_compaction_threshold",
                                "max_compaction_threshold",
                                "row_cache_save_period_in_seconds",
                                "key_cache_save_period_in_seconds",
                                "memtable_flush_after_mins",
                                "memtable_throughput_in_mb",
                                "memtable_operations_in_millions",
                                "replicate_on_write"), text)
            return startswith(common, text.upper()) + types + keywords + props

        if target == 'KEYSPACE':
            props = ("replication_factor", "strategy_options", "strategy_class")
            return startswith(common, text.upper()) + startswith(props, text)

        return []
    complete_CREATE = complete_create

    def complete_drop(self, text, line, begidx, endidx):
        """Complete the object kind after DROP; nothing is offered beyond it."""
        words = line.split()
        if len(words) < 3:
            return startswith(['COLUMNFAMILY', 'KEYSPACE'], text.upper())
        return []
    complete_DROP = complete_drop

    def completenames(self, text, *ignored):
        """Complete the first word: built-in commands plus CQL verbs."""
        cmds = startswith(('USE', 'SELECT', 'UPDATE', 'DELETE', 'CREATE', 'DROP'),
                          text.upper())
        return cmd.Cmd.completenames(self, text, *ignored) + cmds

    def set_prompt(self, prompt):
        """Change the prompt, but only when input comes from a terminal."""
        if sys.stdin.isatty():
            self.prompt = prompt

    # The do_* methods below intentionally carry comments rather than
    # docstrings: cmd.Cmd uses do_* docstrings as help text, so adding one
    # would change what the "help" command shows.

    # End of input: finish the terminal line, then exit like "exit".
    def do_EOF(self, arg):
        if sys.stdin.isatty():
            self.printout("")
        self.do_exit(None)

    # Leave the shell by raising SystemExit.
    def do_exit(self, arg):
        sys.exit()
    do_quit = do_exit

    def printout(self, text, color=None, newline=True, out=None):
        """Write ``text`` to ``out`` (``sys.stdout`` when not given).

        ``color`` is a ``%s`` template applied only when the shell was
        created with colour enabled; ``newline`` appends a line break.
        The default stream is looked up at call time so that a replaced
        ``sys.stdout`` is honoured.
        """
        if out is None:
            out = sys.stdout
        if color and self.color:
            text = color % text
        out.write(text)

        if newline:
            out.write("\n")

    def printerr(self, text, color=None, newline=True):
        """Same as ``printout`` but writes to ``sys.stderr``."""
        self.printout(text, color, newline, sys.stderr)

if __name__ == '__main__':
    # Command line: optional positional host and port, plus colour and
    # authentication flags.
    parser = OptionParser(usage = "Usage: %prog [host [port]]")
    parser.add_option("-C", "--color", action="store_true",
                      help="Enable color output.")
    parser.add_option("-u", "--username", help="Authenticate as user.")
    parser.add_option("-p", "--password", help="Authenticate using password.")
    (options, arguments) = parser.parse_args()

    # Host falls back to localhost when absent or given as an empty string.
    if len(arguments) > 0 and arguments[0]:
        hostname = arguments[0]
    else:
        hostname = "localhost"

    # Port falls back to 9160; a non-numeric value is a usage error.
    port = 9160
    if len(arguments) > 1:
        try:
            port = int(arguments[1])
        except ValueError:
            print >>sys.stderr, "%s is not a valid port number" % arguments[1]
            parser.print_help(file=sys.stderr)
            sys.exit(1)

    shell = Shell(hostname, port, color=options.color,
                  username=options.username, password=options.password)

    # cmdloop() is re-entered after every handled error so that a failing
    # statement does not end the session; only SystemExit leaves the loop.
    while True:
        try:
            shell.cmdloop()
        except SystemExit:
            # Persist the readline history before leaving.
            readline.write_history_file(HISTORY)
            break
        except cql.Error, cqlerr:
            shell.printerr(str(cqlerr), color=RED)
        except KeyboardInterrupt:
            # Abandon the statement being typed and start a fresh line.
            shell.reset_statement()
            print
        except Exception, err:
            shell.printerr("Exception: %s" % err, color=RED)

