#!/bin/sh
# -*- mode: Python -*-

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""":"
# POSIX sh code here (this file is first run by /bin/sh): it finds a suitable
# Python 2 interpreter and re-execs this same file with it. To Python this
# whole section is just a string literal, so it is never run as Python.
# Prefer the unqualified "python" if it is a 2.x version from 2.5 on:
python -c 'import sys; sys.exit(not (0x020500b0 < sys.hexversion < 0x03000000))' 2>/dev/null \
    && exec python "$0" "$@"
# Otherwise try version-qualified interpreter names. 'command -v' is the
# POSIX way to look up an executable; 'which' is not standardized, and some
# implementations print a message and still exit 0 when nothing is found.
for pyver in 2.6 2.7 2.5; do
    command -v "python$pyver" > /dev/null 2>&1 && exec "python$pyver" "$0" "$@"
done
echo "No appropriate python interpreter found." >&2
exit 1
":"""

from __future__ import with_statement

# Program description shown by --help (OptionParser below); version is
# reported by --version and by Shell.show_version().
description = "CQL Shell for Apache Cassandra"
version = "2.3.0"

from StringIO import StringIO
from itertools import groupby
from contextlib import contextmanager, closing
from glob import glob
from functools import partial
from uuid import UUID

import cmd
import sys
import os
import time
import optparse
import ConfigParser
import codecs
import locale
import platform
import warnings
import csv

try:
    import readline
except ImportError:
    readline = None

# Filename prefixes of the python-cql and thrift libraries when they are
# bundled as zip archives.
CQL_LIB_PREFIX = 'cql-internal-only-'
THRIFT_LIB_PREFIX = 'thrift-python-internal-only-'

# use bundled libs for python-cql and thrift, if available. if there
# is a ../lib dir, use bundled libs there preferentially.
# Symlinks are resolved first (as is done for the pylib dir below) so that a
# cqlsh reached through a symlink still finds the lib dir next to the real
# script rather than next to the link.
ZIPLIB_DIRS = [os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'lib')]
myplatform = platform.system()
if myplatform == 'Linux':
    # presumably where Linux packages install Cassandra's jars/zips
    ZIPLIB_DIRS.append('/usr/share/cassandra/lib')

# Any non-empty CQLSH_NO_BUNDLED disables the bundled zips entirely; the
# drivers must then be importable from the normal module path.
if os.environ.get('CQLSH_NO_BUNDLED', ''):
    ZIPLIB_DIRS = ()

def find_zip(libprefix, dirs=None):
    """
    Locate a bundled library zip whose filename starts with libprefix.

    Each directory in dirs (default: ZIPLIB_DIRS) is searched in order for
    '<libprefix>*.zip'. The path from the first directory with any match is
    returned, or None if no directory has one.

    When a directory holds several versions, names are compared with their
    embedded numbers treated as integers. So 'x-1.10.0.zip' wins over
    'x-1.9.0.zip', where a plain string max() would pick 1.9.0.
    """
    import re

    def version_key(path):
        # alternating text / integer chunks: 'a-1.10.zip' -> ['a-', 1, '.', 10, '.zip']
        return [int(chunk) if chunk.isdigit() else chunk
                for chunk in re.split(r'(\d+)', path)]

    if dirs is None:
        dirs = ZIPLIB_DIRS
    for ziplibdir in dirs:
        zips = glob(os.path.join(ziplibdir, libprefix + '*.zip'))
        if zips:
            return max(zips, key=version_key)
    return None

# Put any bundled zips at the front of sys.path so they win over installed
# copies. The cql zip is expected to hold its package inside a
# 'cql-<version>' directory, so that inner directory is what gets added.
cql_zip = find_zip(CQL_LIB_PREFIX)
if cql_zip:
    ver = os.path.splitext(os.path.basename(cql_zip))[0]
    ver = ver[len(CQL_LIB_PREFIX):]
    sys.path.insert(0, os.path.join(cql_zip, 'cql-%s' % (ver,)))

thrift_zip = find_zip(THRIFT_LIB_PREFIX)
if thrift_zip:
    sys.path.insert(0, thrift_zip)

try:
    import cql
except ImportError, e:
    sys.exit("\nPython CQL driver not installed, or not on PYTHONPATH.\n"
             'You might try "easy_install cql".\n\n'
             'Python: %s\n'
             'Module load path: %r\n\n'
             'Error: %s\n' % (sys.executable, sys.path, e))

import cql.decoders
from cql.cursor import _COUNT_DESCRIPTION, _VOID_DESCRIPTION
from cql.cqltypes import (cql_types, cql_typename, lookup_casstype, lookup_cqltype,
                          CassandraType)

# cqlsh should run correctly when run out of a Cassandra source tree,
# out of an unpacked Cassandra tarball, and after a proper package install.
# Look for cqlshlib in ../pylib relative to the real (symlink-resolved)
# script location, and use it when that directory exists; otherwise the
# package must already be importable.
cqlshlibdir = os.path.realpath(__file__)
cqlshlibdir = os.path.join(os.path.dirname(cqlshlibdir), '..', 'pylib')
if os.path.isdir(cqlshlibdir):
    sys.path.insert(0, cqlshlibdir)

from cqlshlib import cqlhandling, cql3handling, pylexotron
from cqlshlib.displaying import (RED, BLUE, ANSI_RESET, COLUMN_NAME_COLORS,
                                 FormattedValue, colorme)
from cqlshlib.formatting import format_by_type
from cqlshlib.util import trim_if_present
from cqlshlib.tracing import print_trace_session

# Per-user files in the home directory. CONFIG_FILE is passed to the
# transport factory when Shell connects; HISTORY is presumably the readline
# history file (its use is not shown here).
CONFIG_FILE = os.path.expanduser(os.path.join('~', '.cqlshrc'))
HISTORY = os.path.expanduser(os.path.join('~', '.cqlsh_history'))
# Connection defaults; DEFAULT_HOST and DEFAULT_PORT are also quoted in the
# --help epilog.
DEFAULT_HOST = 'localhost'
DEFAULT_PORT = 9160
DEFAULT_CQLVER = '3'
# Dotted name of the default Thrift transport factory function (see -t).
DEFAULT_TRANSPORT_FACTORY = 'cqlshlib.tfactory.regular_transport_factory'

# Display defaults handed to format_by_type: a strftime-style timestamp
# format, and the float precision setting.
DEFAULT_TIME_FORMAT = '%Y-%m-%d %H:%M:%S%z'
DEFAULT_FLOAT_PRECISION = 5

# libedit's readline emulation needs a different binding syntax (see
# Shell.prepare_loop), so there the completion key is a literal tab character
# rather than GNU readline's 'tab' key name. The module's __doc__ can be None
# (e.g. under python -OO); treat that as "not libedit" instead of crashing
# with a TypeError at startup.
if readline is not None and 'libedit' in (readline.__doc__ or ''):
    DEFAULT_COMPLETEKEY = '\t'
else:
    DEFAULT_COMPLETEKEY = 'tab'

# Text printed after the option list in --help output; it quotes the
# default host and port via globals().
epilog = """Connects to %(DEFAULT_HOST)s:%(DEFAULT_PORT)d by default. These
defaults can be changed by setting $CQLSH_HOST and/or $CQLSH_PORT. When a
host (and optional port number) are given on the command line, they take
precedence over any defaults.""" % globals()

# Command-line interface. Options appear in --help in registration order.
# -2 and -3 store into the same 'cqlversion' destination as --cqlversion,
# so whichever of them comes last on the command line wins.
parser = optparse.OptionParser(description=description, epilog=epilog,
                               usage="Usage: %prog [options] [host [port]]",
                               version='cqlsh ' + version)
parser.add_option("-C", "--color", action='store_true', dest='color',
                  help='Always use color output')
parser.add_option("--no-color", action='store_false', dest='color',
                  help='Never use color output')
parser.add_option("-u", "--username", help="Authenticate as user.")
parser.add_option("-p", "--password", help="Authenticate using password.")
parser.add_option('-k', '--keyspace', help='Authenticate to the given keyspace.')
parser.add_option("-f", "--file", help="Execute commands from FILE, then exit")
parser.add_option("-t", "--transport-factory",
                  help="Use the provided Thrift transport factory function.")
parser.add_option('--debug', action='store_true',
                  help='Show additional debugging information')
parser.add_option('--cqlversion', default=DEFAULT_CQLVER,
                  help='Specify a particular CQL version (default: %default).'
                       ' Examples: "2", "3.0.0-beta1"')
parser.add_option("-2", "--cql2", action="store_const", dest='cqlversion', const='2',
                  help="Shortcut notation for --cqlversion=2")
parser.add_option("-3", "--cql3", action="store_const", dest='cqlversion', const='3',
                  help="Shortcut notation for --cqlversion=3")


# Exception types that represent a failed CQL request. Thrift's base
# exception is included as well, but only when the thrift package imports.
CQL_ERRORS = (cql.Error,)
try:
    from thrift.Thrift import TException
    CQL_ERRORS = CQL_ERRORS + (TException,)
except ImportError:
    pass

# True only when the environment has CQLSH_DEBUG_COMPLETION=YES.
debug_completion = os.environ.get('CQLSH_DEBUG_COMPLETION', '') == 'YES'

# Built-in keyspaces; several Shell helpers treat them specially.
SYSTEM_KEYSPACES = ('system', 'system_traces')

# we want the cql parser to understand our cqlsh-specific commands too.
# These shell-only command words may be ended by a newline as well as by
# ';' (compare <specialCommand> in cqlsh_extra_syntax_rules).
my_commands_ending_with_newline = (
    'help', '?',
    'consistency',
    'describe', 'desc',
    'show', 'assume', 'source', 'capture', 'debug', 'tracing',
    'exit', 'quit',
)

# Registry of (rule name, term/binding name, completer function) triples,
# filled in by the @cqlsh_syntax_completer decorator.
cqlsh_syntax_completers = []
def cqlsh_syntax_completer(rulename, termname):
    """
    Decorator factory: record the decorated function as the completer for
    `termname` within grammar rule `rulename`, and return it unchanged.
    """
    def registrator(completer):
        entry = (rulename, termname, completer)
        cqlsh_syntax_completers.append(entry)
        return completer
    return registrator

# Extra grammar rules for cqlsh's own (non-CQL) commands. A <cqlshCommand> is
# either an ordinary CQL statement or one of these special commands ended by
# ';' or a newline; handle_statement() parses with startsymbol 'cqlshCommand'.
# Named captures such as cf=, fname= and [colnames]= are what the completer
# functions below read back via ctxt.get_binding(); the [name]= form appears
# to collect repeated matches, since the completers iterate/index those.
cqlsh_extra_syntax_rules = r'''
<cqlshCommand> ::= <CQL_Statement>
                 | <specialCommand> ( ";" | "\n" )
                 ;

<specialCommand> ::= <describeCommand>
                   | <consistencyCommand>
                   | <showCommand>
                   | <assumeCommand>
                   | <sourceCommand>
                   | <captureCommand>
                   | <copyCommand>
                   | <debugCommand>
                   | <helpCommand>
                   | <tracingCommand>
                   | <exitCommand>
                   ;

<describeCommand> ::= ( "DESCRIBE" | "DESC" )
                                  ( "KEYSPACE" ksname=<keyspaceName>?
                                  | ( "COLUMNFAMILY" | "TABLE" ) cf=<columnFamilyName>
                                  | ( "COLUMNFAMILIES" | "TABLES" )
                                  | "SCHEMA"
                                  | "CLUSTER" )
                    ;

<consistencyCommand> ::= "CONSISTENCY" ( level=<consistencyLevel> )?
                       ;

<consistencyLevel> ::= "ANY"
                     | "ONE"
                     | "TWO"
                     | "THREE"
                     | "QUORUM"
                     | "ALL"
                     | "LOCAL_QUORUM"
                     | "EACH_QUORUM"
                     ;

<showCommand> ::= "SHOW" what=( "VERSION" | "HOST" | "ASSUMPTIONS" )
                ;

<assumeCommand> ::= "ASSUME" cf=<columnFamilyName> <assumeTypeDef>
                                                   ( "," <assumeTypeDef> )*
                  ;

<assumeTypeDef> ::= "NAMES" "ARE" names=<storageType>
                  | "VALUES" "ARE" values=<storageType>
                  | "(" colname=<colname> ")" "VALUES" "ARE" colvalues=<storageType>
                  ;

<sourceCommand> ::= "SOURCE" fname=<stringLiteral>
                  ;

<captureCommand> ::= "CAPTURE" ( fname=( <stringLiteral> | "OFF" ) )?
                   ;

<copyCommand> ::= "COPY" cf=<columnFamilyName>
                         ( "(" [colnames]=<colname> ( "," [colnames]=<colname> )* ")" )?
                         ( dir="FROM" ( fname=<stringLiteral> | "STDIN" )
                         | dir="TO"   ( fname=<stringLiteral> | "STDOUT" ) )
                         ( "WITH" <copyOption> ( "AND" <copyOption> )* )?
                ;

<copyOption> ::= [optnames]=<identifier> "=" [optvals]=<copyOptionVal>
               ;

<copyOptionVal> ::= <identifier>
                  | <stringLiteral>
                  ;

# avoiding just "DEBUG" so that this rule doesn't get treated as a terminal
<debugCommand> ::= "DEBUG" "THINGS"?
                 ;

<helpCommand> ::= ( "HELP" | "?" ) [topic]=( /[a-z_]*/ )*
                ;

<tracingCommand> ::= "TRACING" ( switch=( "ON" | "OFF" ) )?
                   ;

<exitCommand> ::= "exit" | "quit"
                ;

<qmark> ::= "?" ;
'''

@cqlsh_syntax_completer('helpCommand', 'topic')
def complete_help(ctxt, cqlsh):
    """Offer every help topic (cqldocs' plus the shell's own), uppercased and sorted."""
    topics = cqldocs.get_help_topics() + cqlsh.get_help_topics()
    upper = [topic.upper() for topic in topics]
    upper.sort()
    return upper

@cqlsh_syntax_completer('assumeTypeDef', 'colname')
def complete_assume_col(ctxt, cqlsh):
    """
    Complete <colname> in ASSUME <cf>(<colname>) VALUES ARE <type>.

    Offers the table's declared column names plus its key alias (or 'KEY'
    when it has none), each passed through cqlsh.cql_protect_name.
    """
    ksname = ctxt.get_binding('ks', None)
    if ksname is not None:
        ksname = cqlsh.cql_unprotect_name(ksname)
    cfname = cqlsh.cql_unprotect_name(ctxt.get_binding('cf'))
    cfdef = cqlsh.get_columnfamily(cfname, ksname=ksname)
    candidates = [colmeta.name for colmeta in cfdef.column_metadata]
    candidates.append(cfdef.key_alias or 'KEY')
    return map(cqlsh.cql_protect_name, candidates)

def complete_source_quoted_filename(ctxt, cqlsh):
    """
    Complete a partially typed filename (used for SOURCE, CAPTURE and COPY).

    The typed text is split into a directory part and a name prefix. Every
    entry of that directory (after ~ expansion) whose name starts with the
    prefix is returned, re-joined with the directory part as typed.
    Subdirectories get a trailing '/'. An unreadable directory yields ().
    """
    typed = ctxt.get_binding('partial', '')
    dirpart, prefix = os.path.split(typed)
    realdir = os.path.expanduser(dirpart)
    try:
        entries = os.listdir(realdir or '.')
    except OSError:
        return ()
    completions = []
    for entry in entries:
        if not entry.startswith(prefix):
            continue
        candidate = os.path.join(dirpart, entry)
        if os.path.isdir(os.path.join(realdir, entry)):
            candidate += '/'
        completions.append(candidate)
    return completions

# SOURCE and CAPTURE complete their file argument with the same helper.
cqlsh_syntax_completer('sourceCommand', 'fname')(complete_source_quoted_filename)
cqlsh_syntax_completer('captureCommand', 'fname')(complete_source_quoted_filename)

@cqlsh_syntax_completer('copyCommand', 'fname')
def copy_fname_completer(ctxt, cqlsh):
    """
    Complete the filename in COPY ... FROM/TO.

    Inside an unclosed string literal the path itself is completed.
    Otherwise an opening quote is suggested when nothing has been typed,
    and nothing is offered once something has.
    """
    if ctxt.get_binding('*LASTTYPE*') == 'unclosedString':
        return complete_source_quoted_filename(ctxt, cqlsh)
    return ["'"] if ctxt.get_binding('partial') == '' else ()

@cqlsh_syntax_completer('copyCommand', 'colnames')
def complete_copy_column_names(ctxt, cqlsh):
    """
    Complete a column name in COPY <table> (<col>, <col>, ...).

    For the first column only the table's first column is offered (the
    cfdef-based lookup lists the key column first). After that, the
    remaining columns not yet listed are offered as a set.
    """
    existcols = [cqlsh.cql_unprotect_name(c) for c in ctxt.get_binding('colnames', ())]
    ks = ctxt.get_binding('ksname', None)
    if ks is not None:
        # no keyspace qualifier typed: leave None (same guard as complete_assume_col)
        ks = cqlsh.cql_unprotect_name(ks)
    cf = cqlsh.cql_unprotect_name(ctxt.get_binding('cfname'))
    colnames = cqlsh.get_column_names(ks, cf)
    if len(existcols) == 0:
        # slice rather than index so a table with no columns yields [] instead of IndexError
        return colnames[:1]
    return set(colnames[1:]) - set(existcols)

# Option names accepted in COPY ... WITH <option>=<value> [AND ...].
COPY_OPTIONS = ('DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'ENCODING', 'NULL')

@cqlsh_syntax_completer('copyOption', 'optnames')
def complete_copy_options(ctxt, cqlsh):
    """
    Complete an option name after COPY ... WITH / AND.

    Returns the set of COPY_OPTIONS not already given in this statement,
    compared case-insensitively. ENCODING and NULL are not offered for
    COPY ... FROM.
    """
    # s.upper() works for both str and unicode bindings; the unbound
    # str.upper rejects unicode objects under Python 2.
    used = set(name.upper() for name in ctxt.get_binding('optnames', ()))
    direction = ctxt.get_binding('dir').upper()
    opts = set(COPY_OPTIONS) - used
    if direction == 'FROM':
        # in-place set difference needs a set operand; a tuple raises TypeError
        opts -= set(('ENCODING', 'NULL'))
    return opts

@cqlsh_syntax_completer('copyOption', 'optvals')
def complete_copy_opt_values(ctxt, cqlsh):
    """
    Complete the value of the most recently named COPY option: true/false
    for HEADER, otherwise a generic single-character-string hint.
    """
    optnames = ctxt.get_binding('optnames', ())
    if optnames[-1].lower() != 'header':
        return [cqlhandling.Hint('<single_character_string>')]
    return ['true', 'false']

class NoKeyspaceError(Exception):
    """An operation needs a keyspace, but none was given or a usable one is not in use."""

class KeyspaceNotFound(Exception):
    """The named keyspace does not exist on the server."""

class ColumnFamilyNotFound(Exception):
    """The named column family (table) could not be found."""

class VersionNotSupported(Exception):
    """The server rejected the requested CQL version."""

class DecodeError(Exception):
    """
    Stand-in for a column name or value whose bytes could not be decoded as
    the expected type. The raw bytes are kept so format_value() can still
    display them.

    thebytes:     the undecodable raw value
    err:          the underlying decoding error
    expectedtype: the type decoding was attempted as
    colname:      the column the value belongs to; None means thebytes was
                  itself a column name
    """

    def __init__(self, thebytes, err, expectedtype, colname=None):
        self.thebytes, self.err = thebytes, err
        self.expectedtype, self.colname = expectedtype, colname

    def __str__(self):
        return str(self.thebytes)

    def message(self):
        if self.colname is None:
            what = 'column name %r' % (self.thebytes,)
        else:
            what = 'value %r (for column %r)' % (self.thebytes, self.colname)
        return 'Failed to decode %s as %s: %s' % (what, self.expectedtype, self.err)

    def __repr__(self):
        return '<%s %s>' % (type(self).__name__, self.message())

def full_cql_version(ver):
    """
    Normalize a CQL version string to at least three numeric components.

    Returns (normalized_string, vertuple). vertuple holds the numeric
    components as ints followed by the (possibly empty) '-' suffix, e.g.
    '3' -> ('3.0.0', (3, 0, 0, '')) and
    '3-beta1' -> ('3.0.0-beta1', (3, 0, 0, 'beta1')).

    The suffix is split off before padding, so the '.0' components go onto
    the numeric part rather than after the suffix.
    """
    base, sep, suffix = ver.partition('-')
    while base.count('.') < 2:
        base += '.0'
    vertuple = tuple(map(int, base.split('.'))) + (suffix,)
    return base + sep + suffix, vertuple

def format_value(val, typeclass, output_encoding, addcolor=False, time_format=None,
                 float_precision=None, colormap=None, nullval=None):
    """
    Render one value for display.

    A DecodeError (a value that failed to decode) is shown as its raw bytes,
    colored with the 'hex' color when addcolor is set. Anything else goes to
    cqlshlib.formatting.format_by_type. typeclass may be a CassandraType
    subclass, or anything lookup_casstype() resolves (presumably including
    type-name strings; that contract is not visible here).
    """
    if isinstance(val, DecodeError):
        if addcolor:
            return colorme(val.thebytes, colormap, 'hex')
        else:
            return FormattedValue(val.thebytes)
    # issubclass() raises TypeError for non-class arguments, so only treat
    # genuine classes as already-resolved types; defer everything else.
    if not (isinstance(typeclass, type) and issubclass(typeclass, CassandraType)):
        typeclass = lookup_casstype(typeclass)
    return format_by_type(typeclass, val, output_encoding, colormap=colormap,
                          addcolor=addcolor, nullval=nullval, time_format=time_format,
                          float_precision=float_precision)

def show_warning_without_quoting_line(message, category, filename, lineno, file=None, line=None):
    """
    Drop-in replacement for warnings.showwarning.

    Formats the warning without echoing the offending source line (the
    `line` argument is ignored) and writes it to `file`, or sys.stderr when
    file is None. An IOError while writing is silently ignored.
    """
    stream = file if file is not None else sys.stderr
    try:
        stream.write(warnings.formatwarning(message, category, filename, lineno, line=''))
    except IOError:
        pass
# Install show_warning_without_quoting_line as the global warnings display
# hook, and show every UnexpectedTableStructure warning rather than only its
# first occurrence.
warnings.showwarning = show_warning_without_quoting_line
warnings.filterwarnings('always', category=cql3handling.UnexpectedTableStructure)

def describe_interval(seconds):
    """
    Describe a duration given in seconds as English text, e.g.
    90061 -> '1 day, 1 hour, 1 minute, and 1.000 seconds'.

    Whole days, hours and minutes are listed only when non-zero (pluralized
    when greater than one). The leftover seconds always appear, with three
    decimal places.
    """
    desc = []
    for length, unit in ((86400, 'day'), (3600, 'hour'), (60, 'minute')):
        # explicit floor division: '/' only floors ints under Python 2
        # semantics and would produce fractions otherwise
        num = int(seconds) // length
        if num > 0:
            desc.append('%d %s%s' % (num, unit, 's' if num > 1 else ''))
        seconds %= length
    words = '%.03f seconds' % seconds
    if len(desc) > 1:
        words = ', '.join(desc) + ', and ' + words
    elif len(desc) == 1:
        words = desc[0] + ' and ' + words
    return words

class Shell(cmd.Cmd):
    # Prompts: normal, continuation (statement not yet terminated), and the
    # variants used while a keyspace is in use. The continuation variant's
    # %s is filled with spaces matching the keyspace name's length.
    default_prompt  = "cqlsh> "
    continue_prompt = "   ... "
    keyspace_prompt          = "cqlsh:%s> "
    keyspace_continue_prompt = "%s    ... "
    num_retries = 4
    # Enabled for non-tty input (see __init__); changes parse error messages.
    show_line_nums = False
    # When set, statement errors print a full traceback (see onecmd).
    debug = False
    # cmdloop() runs until this becomes true.
    stop = False
    # Last statement text recorded for readline history (handle_statement).
    last_hist = None
    shunted_query_out = None
    # Keyword arguments in the csv module's dialect vocabulary; presumably
    # the defaults for COPY's CSV reading/writing.
    csv_dialect_defaults = dict(delimiter=',', doublequote=False,
                                escapechar='\\', quotechar='"')

    def __init__(self, hostname, port, transport_factory, color=False,
                 username=None, password=None, encoding=None, stdin=None, tty=True,
                 completekey=DEFAULT_COMPLETEKEY, use_conn=None,
                 cqlver=None, keyspace=None, tracing_enabled=False,
                 display_time_format=DEFAULT_TIME_FORMAT,
                 display_float_precision=DEFAULT_FLOAT_PRECISION):
        """
        Set up the shell and its CQL connection.

        If use_conn is given it is used as the connection as-is. Otherwise a
        new connection to hostname:port is opened. Its transport comes from
        transport_factory(hostname, port, os.environ, CONFIG_FILE); the
        requested CQL version is then set on the server, and keyspace (if
        any) is selected with USE.

        After connecting, the server versions are read and the CQL 3 syntax
        variant is chosen. Output encoding defaults to the locale's preferred
        encoding. With tty true the prompt is set and a connection banner is
        printed; otherwise show_line_nums is enabled.

        NOTE(review): on the use_conn path set_expanded_cql_version() is not
        called, so cql_version / cql_ver_tuple are never assigned here. Yet
        is_cql3_beta() (reached below when Cassandra >= 1.2), show_version()
        and the version fallback query may read them. Confirm that callers
        passing use_conn set these some other way.
        """
        cmd.Cmd.__init__(self, completekey=completekey)
        self.hostname = hostname
        self.port = port
        self.transport_factory = transport_factory
        self.username = username
        self.password = password
        self.keyspace = keyspace
        self.tracing_enabled = tracing_enabled
        if use_conn is not None:
            self.conn = use_conn
        else:
            transport = transport_factory(hostname, port, os.environ, CONFIG_FILE)
            self.conn = cql.connect(hostname, port, user=username, password=password,
                                    cql_version=cqlver, transport=transport)
            self.set_expanded_cql_version(cqlver)
            # we could set the keyspace through cql.connect(), but as of 1.0.10,
            # it doesn't quote the keyspace for USE :(
            if keyspace is not None:
                tempcurs = self.conn.cursor()
                tempcurs.execute('USE %s;' % self.cql_protect_name(keyspace))
                tempcurs.close()
        self.cursor = self.conn.cursor()
        self.get_connection_versions()

        # use 3.0.0-beta1 syntax if explicitly requested, or if using
        # cassandra < 1.2. this only affects use of cql3; cql2 syntax
        # in either case is the same.
        if self.cassandraver_atleast(1, 2) and not self.is_cql3_beta():
            cql3handling.use_post_3_0_0_syntax()
        else:
            cql3handling.use_pre_3_0_0_syntax()

        self.current_keyspace = keyspace

        self.color = color
        self.display_time_format = display_time_format
        self.display_float_precision = display_float_precision
        if encoding is None:
            encoding = locale.getpreferredencoding()
        self.encoding = encoding
        self.output_codec = codecs.lookup(encoding)

        # buffer accumulating the (possibly multi-line) statement being typed
        self.statement = StringIO()
        self.lineno = 1
        self.in_comment = False
        # ASSUME overrides, keyed by (keyspace, columnfamily) (see show_assumptions)
        self.schema_overrides = {}

        self.prompt = ''
        if stdin is None:
            stdin = sys.stdin
        self.tty = tty
        if tty:
            self.reset_prompt()
            self.report_connection()
            print 'Use HELP for help.'
        else:
            self.show_line_nums = True
        self.stdin = stdin
        self.query_out = sys.stdout

    def set_expanded_cql_version(self, ver):
        """
        Normalize ver with full_cql_version(), ask the server to use it, and
        only after that succeeds record it as cql_version / cql_ver_tuple.
        """
        fullver, vertuple = full_cql_version(ver)
        self.set_cql_version(fullver)
        self.cql_version, self.cql_ver_tuple = fullver, vertuple

    def is_cql3_beta(self):
        """True exactly when the recorded CQL version is 3.0.0-beta1."""
        beta1 = (3, 0, 0, 'beta1')
        return self.cql_ver_tuple == beta1

    def cqlver_atleast(self, major, minor=0, patch=0):
        """Compare the numeric (major, minor, patch) part of the CQL version."""
        wanted = (major, minor, patch)
        return self.cql_ver_tuple[:3] >= wanted

    def cassandraver_atleast(self, major, minor=0, patch=0):
        """
        True if the connected Cassandra's (major, minor, patch) version is at
        least the given one.

        Missing trailing components of cass_ver_tuple count as 0. Otherwise a
        short tuple such as (1, 2) would compare less than (1, 2, 0) and
        report a '1.2' server as older than 1.2.
        """
        current = (tuple(self.cass_ver_tuple[:3]) + (0, 0, 0))[:3]
        return current >= (major, minor, patch)

    def myformat_value(self, val, casstype, **kwargs):
        """
        format_value() with this shell's output codec, color, time format
        and float precision settings. Undecodable values (DecodeError) are
        also appended to self.decoding_errors.
        """
        if isinstance(val, DecodeError):
            self.decoding_errors.append(val)
        codec_name = self.output_codec.name
        return format_value(val, casstype, codec_name,
                            addcolor=self.color,
                            time_format=self.display_time_format,
                            float_precision=self.display_float_precision,
                            **kwargs)

    def myformat_colname(self, name, nametype):
        """Format a column name using the column-name color map."""
        extra = {'colormap': COLUMN_NAME_COLORS}
        return self.myformat_value(name, nametype, **extra)

    def report_connection(self):
        """Print the connection banner: the host line, then the versions line."""
        for show in (self.show_host, self.show_version):
            show()

    def show_host(self):
        print "Connected to %s at %s:%d." % \
               (self.applycolor(self.get_cluster_name(), BLUE),
                self.hostname,
                self.port)

    def show_version(self):
        """Print the cqlsh, Cassandra, CQL spec and Thrift protocol versions."""
        vers = dict(self.connection_versions)
        vers['shver'] = version
        # system.Versions['cql'] apparently does not reflect changes with
        # set_cql_version, so report the version this shell set instead.
        vers['cql'] = self.cql_version
        print ("[cqlsh %(shver)s | Cassandra %(build)s | CQL spec %(cql)s"
               " | Thrift protocol %(thrift)s]") % vers

    def show_assumptions(self):
        all_overrides = self.schema_overrides.items()
        all_overrides.sort()
        if all_overrides:
            print
        else:
            print 'No overrides.'
            return
        for keyspace, ksoverrides in groupby(all_overrides, key=lambda x:x[0][0]):
            keyspace = self.cql_protect_name(keyspace)
            print 'USE %s;' % keyspace
            print
            for (ks, cf), override in ksoverrides:
                cf = self.cql_protect_name(cf)
                if override.default_name_type:
                    print 'ASSUME %s NAMES ARE %s;' \
                          % (cf, cql_typename(override.default_name_type))
                if override.default_value_type:
                    print 'ASSUME %s VALUES ARE %s;' \
                          % (cf, cql_typename(override.default_value_type))
                for colname, vtype in override.value_types.items():
                    colname = self.cql_protect_name(colname)
                    print 'ASSUME %s(%s) VALUES ARE %s;' \
                          % (cf, colname, cql_typename(vtype))
        print

    def get_connection_versions(self):
        """
        Fill in self.connection_versions ({'build', 'thrift', 'cql'}) and
        self.cass_ver_tuple (major, minor, patch as ints).

        Reads system.local when the server has it. If that query raises
        cql.ProgrammingError or returns no row, falls back to
        get_connection_versions_fallback().
        """
        try:
            self.cursor.execute("select * from system.local where key = 'local'")
        except cql.ProgrammingError:
            result = None
        else:
            result = self.fetchdict()
        if result is None:
            vers = self.get_connection_versions_fallback()
        else:
            vers = {
                'build': result['release_version'],
                'thrift': result['thrift_version'],
                'cql': result['cql_version'],
            }
        self.connection_versions = vers
        # Keep only the numeric part before any '-suffix' (e.g. '-SNAPSHOT'),
        # truncated/padded to exactly three components so that builds with
        # extra dotted parts don't crash int() and short ones compare right.
        numeric = vers['build'].split('-', 1)[0].split('.')[:3]
        numeric += ['0'] * (3 - len(numeric))
        self.cass_ver_tuple = tuple(map(int, numeric))

    def get_connection_versions_fallback(self):
        """
        Version info for servers without system.local.

        Reads component -> version pairs from system.Versions. If that table
        is missing too, returns placeholder build/cql values plus the Thrift
        version.
        """
        # under CQL 3 the mixed-case table name must be quoted
        if self.cqlver_atleast(3):
            table = 'system."Versions"'
        else:
            table = 'system.Versions'
        try:
            self.cursor.execute('select component, version from ' + table)
            vers = dict(self.cursor)
        except cql.ProgrammingError:
            # older Cassandra; doesn't have system.Versions
            return {'build': '0.0.0', 'cql': 'unknown', 'thrift': self.get_thrift_version()}
        return vers

    def fetchdict(self):
        """Fetch the next row as a {column name: value} dict, or None when exhausted."""
        row = self.cursor.fetchone()
        if row is not None:
            colnames = [coldesc[0] for coldesc in self.cursor.description]
            return dict(zip(colnames, row))
        return None

    def fetchdict_all(self):
        """Fetch all remaining rows as a list of {column name: value} dicts."""
        return [dict(zip([coldesc[0] for coldesc in self.cursor.description], row))
                for row in self.cursor]

    def get_keyspace_names(self):
        """Names of all keyspaces reported by get_keyspaces()."""
        keyspaces = self.get_keyspaces()
        return [ksdef.name for ksdef in keyspaces]

    def get_columnfamilies(self, ksname=None):
        """
        Column family definitions of ksname (default: the current keyspace).
        Raises NoKeyspaceError if no keyspace is given and none is in use.
        """
        if ksname is None:
            if self.current_keyspace is None:
                raise NoKeyspaceError("Not in any keyspace.")
            ksname = self.current_keyspace
        return self.get_keyspace(ksname).cf_defs

    def get_columnfamily(self, cfname, ksname=None):
        """
        The definition of column family cfname in ksname (default: current
        keyspace). Raises ColumnFamilyNotFound when no such table exists.
        """
        if ksname is None:
            ksname = self.current_keyspace
        for cfdef in self.get_columnfamilies(ksname):
            if cfdef.name == cfname:
                break
        else:
            raise ColumnFamilyNotFound("Unconfigured column family %r" % (cfname,))
        return cfdef

    def get_columnfamily_names(self, ksname=None):
        """Table names in ksname, read from the schema tables under CQL 3."""
        if not self.cqlver_atleast(3):
            return [cfdef.name for cfdef in self.get_columnfamilies(ksname)]
        return self.get_columnfamily_names_cql3(ksname=ksname)

    def get_index_names(self, ksname=None):
        """Names of all secondary indexes on columns of tables in ksname."""
        return [colmeta.index_name
                for cfdef in self.get_columnfamilies(ksname)
                for colmeta in cfdef.column_metadata
                if colmeta.index_name is not None]

    def filterable_column_names(self, cfdef):
        """
        Set of column names usable in a filter: the key (its alias, or 'KEY')
        plus every column that has a secondary index.
        """
        key_alias = cfdef.key_alias
        filterable = set(['KEY' if key_alias is None else key_alias])
        filterable.update(colmeta.name for colmeta in cfdef.column_metadata
                          if colmeta.index_name is not None)
        return filterable

    def get_column_names(self, ksname, cfname):
        """
        Column names of ksname.cfname (ksname defaults to current keyspace).
        Uses the CQL 3 schema layout for non-system keyspaces under CQL 3,
        and the Thrift cfdef otherwise.
        """
        if ksname is None:
            ksname = self.current_keyspace
        if ksname in SYSTEM_KEYSPACES or not self.cqlver_atleast(3):
            return self.get_column_names_from_cfdef(ksname, cfname)
        return self.get_column_names_from_layout(ksname, cfname)

    def get_column_names_from_layout(self, ksname, cfname):
        """Column names, in layout order, from the CQL 3 table definition."""
        columns = self.get_columnfamily_layout(ksname, cfname).columns
        return [column.name for column in columns]

    def get_column_names_from_cfdef(self, ksname, cfname):
        """Key alias (or 'KEY') followed by the sorted declared column names."""
        cfdef = self.get_columnfamily(cfname, ksname=ksname)
        key_alias = 'KEY' if cfdef.key_alias is None else cfdef.key_alias
        names = sorted(colmeta.name for colmeta in cfdef.column_metadata)
        return [key_alias] + names

    # ===== thrift-dependent parts =====

    def get_cluster_name(self):
        """Cluster name from the Thrift describe_cluster_name call."""
        method = 'describe_cluster_name'
        return self.make_hacktastic_thrift_call(method)

    def get_partitioner(self):
        """Partitioner from the Thrift describe_partitioner call."""
        method = 'describe_partitioner'
        return self.make_hacktastic_thrift_call(method)

    def get_snitch(self):
        """Snitch from the Thrift describe_snitch call."""
        method = 'describe_snitch'
        return self.make_hacktastic_thrift_call(method)

    def get_thrift_version(self):
        """Thrift protocol version from the describe_version call."""
        method = 'describe_version'
        return self.make_hacktastic_thrift_call(method)

    def get_ring(self):
        """
        Token ring of the current keyspace (Thrift describe_ring). Raises
        NoKeyspaceError when no keyspace, or a system keyspace, is in use.
        """
        ks = self.current_keyspace
        if ks is None or ks in SYSTEM_KEYSPACES:
            raise NoKeyspaceError("Ring view requires a current non-system keyspace")
        return self.make_hacktastic_thrift_call('describe_ring', ks)

    def get_keyspace(self, ksname):
        """Keyspace definition via Thrift; KeyspaceNotFound if it doesn't exist."""
        try:
            ksdef = self.make_hacktastic_thrift_call('describe_keyspace', ksname)
        except cql.cassandra.ttypes.NotFoundException:
            raise KeyspaceNotFound('Keyspace %r not found.' % ksname)
        return ksdef

    def get_keyspaces(self):
        """All keyspace definitions from the Thrift describe_keyspaces call."""
        method = 'describe_keyspaces'
        return self.make_hacktastic_thrift_call(method)

    def get_schema_versions(self):
        """Result of the Thrift describe_schema_versions call."""
        method = 'describe_schema_versions'
        return self.make_hacktastic_thrift_call(method)

    def set_cql_version(self, ver):
        try:
            return self.make_hacktastic_thrift_call('set_cql_version', ver)
        except cql.cassandra.ttypes.InvalidRequestException, e:
            raise VersionNotSupported(e.why)

    def trace_next_query(self):
        """Result of the Thrift trace_next_query call."""
        method = 'trace_next_query'
        return self.make_hacktastic_thrift_call(method)

    def make_hacktastic_thrift_call(self, call, *args):
        """Call the connection's raw Thrift client method named `call` with args."""
        method = getattr(self.conn.client, call)
        return method(*args)

    # ===== end thrift-dependent parts =====

    # ===== cql3-dependent parts =====

    def get_columnfamily_names_cql3(self, ksname=None):
        """
        Table names in ksname (default: current keyspace), read from
        system.schema_columnfamilies. The key column names differ before
        Cassandra 1.2, hence the two query variants.
        """
        if ksname is None:
            ksname = self.current_keyspace
        if self.cassandraver_atleast(1, 2):
            cf_q = """select columnfamily_name from system.schema_columnfamilies
                       where keyspace_name=:ks"""
        else:
            cf_q = """select "columnfamily" from system.schema_columnfamilies
                       where "keyspace"=:ks"""
        params = {'ks': self.cql_unprotect_name(ksname)}
        self.cursor.execute(cf_q, params)
        rows = self.cursor.fetchall()
        return [str(row[0]) for row in rows]

    def get_columnfamily_layout(self, ksname, cfname):
        """
        Build a cql3handling.CqlTableDef for ksname.cfname (ksname defaults
        to the current keyspace). It is built from the table's row in
        system.schema_columnfamilies and its rows in system.schema_columns.
        Raises ColumnFamilyNotFound if the table has no schema row.
        """
        if ksname is None:
            ksname = self.current_keyspace
        if self.cassandraver_atleast(1, 2):
            cf_q = """select * from system.schema_columnfamilies
                       where keyspace_name=:ks and columnfamily_name=:cf"""
            col_q = """select * from system.schema_columns
                        where keyspace_name=:ks and columnfamily_name=:cf"""
        else:
            cf_q = """select * from system.schema_columnfamilies
                       where "keyspace"=:ks and "columnfamily"=:cf"""
            col_q = """select * from system.schema_columns
                        where "keyspace"=:ks and "columnfamily"=:cf"""
        self.cursor.execute(cf_q, dict(ks=ksname, cf=cfname))
        layout = self.fetchdict()
        if layout is None:
            raise ColumnFamilyNotFound("Column family %r not found" % cfname)
        self.cursor.execute(col_q, dict(ks=ksname, cf=cfname))
        return cql3handling.CqlTableDef.from_layout(layout, self.fetchdict_all())

    # ===== end cql3-dependent parts =====

    def reset_statement(self):
        """Discard the pending statement buffer and restore the normal prompt."""
        self.reset_prompt()
        # Rewind before truncating. StringIO.StringIO.truncate(0) happens to
        # reset the position as well, but other file-like buffers (e.g.
        # io.StringIO) do not, and the next write would then land after a
        # run of NUL padding.
        self.statement.seek(0)
        self.statement.truncate(0)

    def reset_prompt(self):
        """Show the normal prompt, including the keyspace name when one is in use."""
        ks = self.current_keyspace
        if ks is not None:
            prompt = self.keyspace_prompt % ks
        else:
            prompt = self.default_prompt
        self.set_prompt(prompt)

    def set_continue_prompt(self):
        """Show the continuation prompt, padded to line up with the keyspace prompt."""
        ks = self.current_keyspace
        if ks is not None:
            padding = ' ' * len(str(ks))
            self.set_prompt(self.keyspace_continue_prompt % (padding,))
        else:
            self.set_prompt(self.continue_prompt)

    @contextmanager
    def prepare_loop(self):
        """
        Context manager wrapped around the input loop.

        When interactive (tty) with a completion key configured, and readline
        can be imported, it installs self.complete as the readline completer
        bound to that key. The previous completer is restored on exit.
        Otherwise it does nothing.
        """
        readline = None
        if self.tty and self.completekey:
            try:
                import readline
            except ImportError:
                pass
            else:
                old_completer = readline.get_completer()
                readline.set_completer(self.complete)
                # libedit's readline emulation uses its own 'bind' syntax.
                # __doc__ may be None (e.g. under python -OO); treat that as
                # GNU readline instead of raising TypeError.
                if 'libedit' in (readline.__doc__ or ''):
                    readline.parse_and_bind("bind -e")
                    readline.parse_and_bind("bind '" + self.completekey + "' rl_complete")
                else:
                    readline.parse_and_bind(self.completekey + ": complete")
        try:
            yield
        finally:
            if readline is not None:
                readline.set_completer(old_completer)

    def get_input_line(self, prompt=''):
        """
        Read one line of input, always ending in '\\n'.

        Uses raw_input() with the prompt on a tty, else self.stdin.readline().
        Raises EOFError at end of input, and advances self.lineno.
        """
        if not self.tty:
            line = self.stdin.readline()
            if not line:
                raise EOFError
        else:
            line = raw_input(prompt) + '\n'
        self.lineno += 1
        return line

    def use_stdin_reader(self, until='', prompt=''):
        """
        Yield input lines until one equals `until` (plus newline) or input
        ends. The terminating line itself is not yielded. With the default
        until='', a blank line ends the input.
        """
        terminator = until + '\n'
        while True:
            try:
                newline = self.get_input_line(prompt=prompt)
            except EOFError:
                break
            if newline == terminator:
                break
            yield newline

    def cmdloop(self):
        """
        Adapted from cmd.Cmd's version, because there is literally no way with
        cmd.Cmd.cmdloop() to tell the difference between "EOF" showing up in
        input and an actual EOF.

        Inside prepare_loop() (readline completion setup), reads input a line
        at a time and appends it to the pending statement buffer. The whole
        buffer goes to onecmd(), and the buffer is reset once onecmd()
        reports it complete. Runs until self.stop is set.

        End of input is handed to handle_eof(), and CQL driver errors are
        printed. A KeyboardInterrupt discards the pending statement.
        """
        with self.prepare_loop():
            while not self.stop:
                try:
                    line = self.get_input_line(self.prompt)
                    # accumulate: one statement may span several input lines
                    self.statement.write(line)
                    if self.onecmd(self.statement.getvalue()):
                        self.reset_statement()
                except EOFError:
                    self.handle_eof()
                except cql.Error, cqlerr:
                    self.printerr(str(cqlerr))
                except KeyboardInterrupt:
                    # drop the partial statement and start on a fresh line
                    self.reset_statement()
                    print

    def onecmd(self, statementtext):
        """
        Returns true if the statement is complete and was handled (meaning it
        can be reset).

        statementtext is everything typed since the last reset. True is also
        returned when the text fails to lex (after printing where) and when
        it contains no statements. None (falsy) is returned, after switching
        to the continuation prompt, while input is incomplete: still inside a
        batch, or the last statement lacks its terminating token.

        Exceptions raised while handling individual statements are printed
        (as a traceback when self.debug is set) rather than propagated.
        """

        try:
            statements, in_batch = cqlruleset.cql_split_statements(statementtext)
        except pylexotron.LexingError, e:
            if self.show_line_nums:
                self.printerr('Invalid syntax at char %d' % (e.charnum,))
            else:
                self.printerr('Invalid syntax at line %d, char %d'
                              % (e.linenum, e.charnum))
            # echo the offending line with a caret under the bad character
            statementline = statementtext.split('\n')[e.linenum - 1]
            self.printerr('  %s' % statementline)
            self.printerr(' %s^' % (' ' * e.charnum))
            return True

        # ignore trailing empty statements
        while statements and not statements[-1]:
            statements = statements[:-1]
        if not statements:
            return True
        if in_batch or statements[-1][-1][0] != 'endtoken':
            # not terminated yet: wait for more input
            self.set_continue_prompt()
            return
        for st in statements:
            try:
                self.handle_statement(st, statementtext)
            except Exception, e:
                if self.debug:
                    import traceback
                    traceback.print_exc()
                else:
                    self.printerr(e)
        return True

    def handle_eof(self):
        """
        React to end of input. Any pending unterminated statement is
        submitted with a ';' appended; if that is still incomplete, an error
        is reported. The shell then exits.
        """
        if self.tty:
            print
        pending = self.statement.getvalue()
        if pending.strip() and not self.onecmd(pending + ';'):
            self.printerr('Incomplete statement at end of file')
        self.do_exit()

    def handle_statement(self, tokens, srcstr):
        """
        Dispatch one lexed statement.

        When readline is available, a statement spanning several lines is
        flattened onto one line and added to the history (skipping an
        immediate repeat). Words with a do_<word> handler on this object
        (with '?' meaning help) are fully parsed and passed to that
        handler. Anything else is sent to the server verbatim.
        """
        if readline is not None:
            flattened = srcstr.replace("\n", " ").rstrip()
            if srcstr.count("\n") > 1 and flattened != self.last_hist:
                readline.add_history(flattened)
            self.last_hist = flattened

        cmdword = tokens[0][1]
        if cmdword == '?':
            cmdword = 'help'
        handler = getattr(self, 'do_' + cmdword.lower(), None)
        if not handler:
            return self.perform_statement(cqlruleset.cql_extract_orig(tokens, srcstr))

        parsed = cqlruleset.cql_whole_parse_tokens(tokens, srcstr=srcstr,
                                                   startsymbol='cqlshCommand')
        if parsed and not parsed.remainder:
            # the whole statement was consumed by the grammar
            return handler(parsed)
        return self.handle_parse_error(cmdword, tokens, parsed, srcstr)

    def handle_parse_error(self, cmdword, tokens, parsed, srcstr):
        """
        Deal with a statement that cqlsh's grammar could not fully parse.

        Standard CQL statement kinds are still sent to the server, since the
        server may accept syntax cqlsh does not know yet (type assumptions
        will not be applied then). For anything else an error is printed,
        pointing at the first unparsed token when a partial parse exists.
        """
        passthrough = ('select', 'insert', 'update', 'delete', 'truncate',
                       'create', 'drop', 'alter', 'grant', 'revoke',
                       'batch', 'list')
        if cmdword.lower() in passthrough:
            return self.perform_statement(cqlruleset.cql_extract_orig(tokens, srcstr))
        if not parsed:
            self.printerr('Improper %s command.' % cmdword)
        else:
            self.printerr('Improper %s command (problem at %r).' % (cmdword, parsed.remainder[0]))

    def do_use(self, parsed):
        # Switch keyspace. The locally tracked keyspace is only updated
        # when the server accepted the USE statement.
        target = parsed.get_binding('ksname')
        succeeded = self.perform_statement_untraced(parsed.extract_orig())
        if succeeded:
            self.current_keyspace = self.cql_unprotect_name(target)

    def do_select(self, parsed):
        # SELECT is executed with the decoder chosen by determine_decoder_for
        # for the target table (keyspace defaults inside that helper).
        raw_ks = parsed.get_binding('ksname')
        target_ks = None if raw_ks is None else self.cql_unprotect_name(raw_ks)
        target_cf = self.cql_unprotect_name(parsed.get_binding('cfname'))
        chosen = self.determine_decoder_for(target_cf, ksname=target_ks)
        self.perform_statement(parsed.extract_orig(), decoder=chosen)

    def perform_statement(self, statement, decoder=None):
        """
        Execute a statement. When tracing is enabled, the trace session for
        it is printed afterwards.

        `decoder` is forwarded to perform_statement_untraced in both modes,
        so callers such as do_select get results decoded with the
        table-specific decoder they selected.

        Returns the result of perform_statement_untraced.
        """
        if not self.tracing_enabled:
            return self.perform_statement_untraced(statement, decoder=decoder)
        session_id = UUID(bytes=self.trace_next_query())
        result = self.perform_statement_untraced(statement, decoder=decoder)
        time.sleep(0.5)  # trace writes are async so we wait a little.
        print_trace_session(self, self.cursor, session_id)
        return result

    def perform_statement_untraced(self, statement, decoder=None):
        """
        Execute one statement on self.cursor (passing `decoder` through to
        cursor.execute) and print its result, if any.

        cql.IntegrityError is retried after sleeping 1*trynum seconds. It
        gives up, returning False, once trynum exceeds self.num_retries.
        Any other error is reported via printerr and False is returned.

        Returns False for an empty statement or on error, and True after a
        successful execution.
        """
        if not statement:
            return False
        trynum = 1
        while True:
            try:
                self.cursor.execute(statement, decoder=decoder)
                break
            except cql.IntegrityError, err:
                self.printerr("Attempt #%d: %s" % (trynum, str(err)))
                trynum += 1
                if trynum > self.num_retries:
                    return False
                time.sleep(1*trynum)
            except cql.ProgrammingError, err:
                self.printerr(str(err))
                # if cqlsh's grammar for the *other* CQL major version parses
                # the statement, suggest the matching -2/-3 startup option
                if self.cqlver_atleast(3):
                    if self.parseable_as_cql2(statement):
                        self.printerr("Perhaps you meant to use CQL 2? Try using"
                                      " the -2 option when starting cqlsh.")
                else:
                    if self.parseable_as_cql3(statement):
                        self.printerr("Perhaps you meant to use CQL 3? Try using"
                                      " the -3 option when starting cqlsh.")
                return False
            except CQL_ERRORS, err:
                self.printerr(str(err))
                return False
            except Exception, err:
                # unexpected failure: show the full traceback to the user
                import traceback
                self.printerr(traceback.format_exc())
                return False

        # COUNT results and statements without a result set are detected by
        # identity against the module-level description sentinels
        if self.cursor.description is _COUNT_DESCRIPTION:
            self.print_count_result(self.cursor)
        elif self.cursor.description is not _VOID_DESCRIPTION:
            self.print_result(self.cursor)
        self.flush_output()
        return True

    # these next two functions are not guaranteed perfect; just checks if the
    # statement parses fully according to cqlsh's own understanding of the
    # grammar. Changes to the language in Cassandra frequently don't get
    # updated in cqlsh right away.

    def parseable_as_cql3(self, statement):
        # True when cqlsh's own CQL 3 rule set matches the whole statement.
        whole_match = cql3handling.CqlRuleSet.lex_and_whole_match(statement)
        return whole_match is not None

    def parseable_as_cql2(self, statement):
        # True when cqlsh's own CQL 2 rule set matches the whole statement.
        whole_match = cqlhandling.CqlRuleSet.lex_and_whole_match(statement)
        return whole_match is not None

    def determine_decoder_for(self, cfname, ksname=None):
        """
        Pick the result decoder for table ksname.cfname (ksname defaults to
        the current keyspace).

        This is ErrorHandlingSchemaDecoder. If self.schema_overrides holds
        entries for that (keyspace, table) pair, they are bound in as
        `overrides` via functools.partial.
        """
        base = ErrorHandlingSchemaDecoder
        table_key = (self.current_keyspace if ksname is None else ksname, cfname)
        overrides = self.schema_overrides.get(table_key, None)
        if not overrides:
            return base
        return partial(base, overrides=overrides)

    def get_nametype(self, cursor, num):
        """
        Return the Cassandra type of the column *name* at zero-based position
        `num` in the cursor's current row, as recorded in cursor.name_info.

        Callers need this to distinguish, e.g., ascii names from blob names
        shown as hex.
        """
        name_entry = cursor.name_info[num]
        return name_entry[1]

    def print_count_result(self, cursor):
        # Render a COUNT result as a small one-column table; nothing at all
        # is written when the result is empty.
        if cursor.result:
            for line in ('count', '-----', cursor.result[0], ""):
                self.writeresult(line)

    def has_static_result_set(self, cursor):
        """
        Report whether all rows of the result share the same description
        (column layout).

        CQL 3 results are always static, so no scan is done for them.
        Otherwise the cursor is iterated to compare descriptions and is then
        always rewound with cursor._reset(), including when a mismatch is
        found. This lets the caller print every row afterwards.
        """
        if self.cqlver_atleast(3):
            return True  # all cql3 resultsets are static, don't bother scanning
        static = True
        last_description = None
        for _ in cursor:
            if last_description is not None and cursor.description != last_description:
                static = False
                break
            last_description = cursor.description
        # rewind in both cases; the dynamic printer iterates from the start
        cursor._reset()
        return static

    def print_result(self, cursor):
        """
        Print the rows in `cursor` between blank lines. The tabular (static)
        layout is used when every row has the same columns, the per-row
        (dynamic) layout otherwise.

        self.decoding_errors is reset first. Afterwards up to two of the
        errors it holds are shown in red, followed by a count of any others.
        """
        self.decoding_errors = []
        self.writeresult("")
        if self.has_static_result_set(cursor):
            printer = self.print_static_result
        else:
            printer = self.print_dynamic_result
        printer(cursor)
        self.writeresult("")

        errors = self.decoding_errors
        if not errors:
            return
        for err in errors[:2]:
            self.writeresult(err.message(), color=RED)
        hidden = len(errors) - 2
        if hidden > 0:
            self.writeresult('%d more decoding errors suppressed.' % hidden, color=RED)

    def print_static_result(self, cursor):
        # Format the header (each column name together with its name type)
        # and every row's values, then hand both to the table printer.
        names = [desc[0] for desc in cursor.description]
        nametypes = [self.get_nametype(cursor, idx) for idx in range(len(names))]
        header = [self.myformat_colname(n, t) for n, t in zip(names, nametypes)]
        body = [map(self.myformat_value, row, cursor.column_types) for row in cursor]
        self.print_formatted_result(header, body)

    def print_formatted_result(self, formatted_names, formatted_values):
        """
        Write an aligned table. The output is a header of left-justified
        names (trailing blanks stripped), a dashed separator, then one line
        of right-justified cells per row. Each column is as wide as its
        widest displayed cell.
        """
        widths = [name.displaywidth for name in formatted_names]
        for fmtrow in formatted_values:
            for idx, cell in enumerate(fmtrow):
                if cell.displaywidth > widths[idx]:
                    widths[idx] = cell.displaywidth

        header_cells = [hdr.color_ljust(w) for hdr, w in zip(formatted_names, widths)]
        self.writeresult(' ' + ' | '.join(header_cells).rstrip())
        self.writeresult('-' + '-+-'.join(['-' * w for w in widths]) + '-')

        for fmtrow in formatted_values:
            cells = [cell.color_rjust(w) for cell, w in zip(fmtrow, widths)]
            self.writeresult(' ' + ' | '.join(cells))

    def print_dynamic_result(self, cursor):
        # Rows may each have different columns, so every row gets its own
        # line of "name,value" cells joined by " | ". The description and
        # column types are re-read from the cursor for each row.
        for row in cursor:
            names = [desc[0] for desc in cursor.description]
            nametypes = [self.get_nametype(cursor, idx) for idx in range(len(names))]
            fmt_names = [self.myformat_colname(n, t) for n, t in zip(names, nametypes)]
            fmt_vals = map(self.myformat_value, row, cursor.column_types)
            cells = ['%s,%s' % (n.coloredval, v.coloredval) for n, v in zip(fmt_names, fmt_vals)]
            self.writeresult(' ' + ' | '.join(cells))

    def emptyline(self):
        # Override cmd.Cmd.emptyline, whose default repeats the last
        # command; an empty input line does nothing here.
        return None

    def parseline(self, line):
        # cmd.Cmd's line parsing is bypassed by this class's own cmdloop and
        # onecmd, so reaching this is unexpected.
        raise NotImplementedError

    def complete(self, text, state):
        """
        readline completer hook.

        On state 0 the candidate list for `text` is computed and cached in
        self.completion_matches. Every call returns the candidate at index
        `state`, or None when there are no more.

        The cache is cleared before recomputing. If find_completions raises
        (printed when debug_completion is set, re-raised otherwise), later
        calls therefore see no candidates rather than stale ones from the
        previous completion, or an AttributeError when none was made yet.
        """
        if readline is None:
            return
        if state == 0:
            self.completion_matches = []
            try:
                self.completion_matches = self.find_completions(text)
            except Exception:
                if debug_completion:
                    import traceback
                    traceback.print_exc()
                else:
                    raise
        try:
            return self.completion_matches[state]
        except IndexError:
            return None

    def find_completions(self, text):
        """
        Compute completion candidates for `text`.

        The text to complete against is the earlier lines of a multi-line
        statement (self.statement) followed by the current readline buffer,
        cut off where the word being completed begins.
        """
        curline = readline.get_line_buffer()
        prevlines = self.statement.getvalue()
        wholestmt = prevlines + curline
        # readline's begin index is relative to the current line only
        begidx = readline.get_begidx() + len(prevlines)
        stuff_to_complete = wholestmt[:begidx]
        return cqlruleset.cql_complete(stuff_to_complete, text, cassandra_conn=self,
                                       debug=debug_completion, startsymbol='cqlshCommand')

    def set_prompt(self, prompt):
        # Store the prompt string used for subsequent input lines.
        setattr(self, 'prompt', prompt)

    def cql_protect_name(self, name):
        # Quote `name` as a CQL identifier where cqlruleset deems it
        # necessary; unicode input is encoded to UTF-8 bytes first.
        encoded = name.encode('utf8') if isinstance(name, unicode) else name
        return cqlruleset.maybe_escape_name(encoded)

    def cql_protect_value(self, value):
        # Render `value` as a CQL literal using cqlruleset.escape_value.
        literal = cqlruleset.escape_value(value)
        return literal

    def cql_unprotect_name(self, namestr):
        # Remove CQL identifier quoting via cqlruleset.dequote_name;
        # None passes through as None.
        if namestr is not None:
            return cqlruleset.dequote_name(namestr)
        return None

    def cql_unprotect_value(self, valstr):
        # Turn a CQL literal back into its plain value via
        # cqlruleset.dequote_value; None yields None.
        if valstr is None:
            return None
        return cqlruleset.dequote_value(valstr)

    def print_recreate_keyspace(self, ksdef, out):
        """
        Write to `out` a CREATE KEYSPACE statement for `ksdef`. If the
        keyspace has tables, a USE statement and each table's CREATE
        statements follow.

        CQL 3 (non-beta) output uses the replication-map syntax (plus
        durable_writes when disabled). Otherwise the strategy_class /
        strategy_options syntax is written.
        """
        stratclass = trim_if_present(ksdef.strategy_class, 'org.apache.cassandra.locator.')
        # quoted form, used only inside the generated CQL text
        ksname = self.cql_protect_name(ksdef.name)
        if self.cqlver_atleast(3) and not self.is_cql3_beta():
            out.write("CREATE KEYSPACE %s WITH replication = {\n" % ksname)
            out.write("  'class': %s" % self.cql_protect_value(stratclass))
            for opname, opval in ksdef.strategy_options.iteritems():
                out.write(",\n  %s: %s" % (self.cql_protect_value(opname),
                                           self.cql_protect_value(opval)))
            out.write("\n}")
            if not ksdef.durable_writes:
                out.write(" AND durable_writes = 'false'")
        else:
            out.write("CREATE KEYSPACE %s WITH strategy_class = %s"
                       % (ksname, self.cql_protect_value(stratclass)))
            for opname, opval in ksdef.strategy_options.iteritems():
                out.write("\n  AND strategy_options:%s = %s" % (opname, self.cql_protect_value(opval)))
        out.write(';\n')

        # schema lookups take the raw keyspace name (as the per-table call
        # below and describe_columnfamilies do), not the quoted one
        cfs = self.get_columnfamily_names(ksdef.name)
        if cfs:
            out.write('\nUSE %s;\n' % ksname)
            for cf in cfs:
                out.write('\n')
                # yes, cf might be looked up again. oh well.
                self.print_recreate_columnfamily(ksdef.name, cf, out)

    def print_recreate_columnfamily(self, ksname, cfname, out):
        """
        Write to `out` CQL that should recreate table ksname.cfname when
        pasted back into a session.

        In CQL 3 (non-beta) mode the table layout from the system schema is
        used, which yields CQL 3 syntax. If fetching that layout fails with a
        CQL error, or when in CQL 2 / CQL 3 beta mode, the output is built
        from the thrift-style cfdef instead. The cfdef route is also the only
        one available for system tables, which lack system.schema_* entries.
        """
        if self.cqlver_atleast(3) and not self.is_cql3_beta():
            try:
                layout = self.get_columnfamily_layout(ksname, cfname)
                have_layout = True
            except CQL_ERRORS:
                # most likely a 1.1 beta where cql3 is supported, but not system.schema_*
                have_layout = False
            if have_layout:
                return self.print_recreate_columnfamily_from_layout(layout, out)

        cfdef = self.get_columnfamily(cfname, ksname=ksname)
        return self.print_recreate_columnfamily_from_cfdef(cfdef, out)

    def print_recreate_columnfamily_from_cfdef(self, cfdef, out):
        """
        Write to `out` a CREATE TABLE statement reconstructed from a
        thrift-style cfdef, followed by a CREATE INDEX statement for every
        column that has an index_name.

        The key column is named after cfdef.key_alias, or KEY when no alias
        is set. Options from cqlruleset.columnfamily_options are read from
        cfdef under their thrift name (falling back to the option name) and
        skipped when unset. comparator and default_validation are written as
        bare type names; every other option is written as a quoted CQL value.
        Map options are flattened to "option:key=value" pairs.
        """
        cfname = self.cql_protect_name(cfdef.name)
        out.write("CREATE TABLE %s (\n" % cfname)
        alias = self.cql_protect_name(cfdef.key_alias) if cfdef.key_alias else 'KEY'
        keytype = cql_typename(cfdef.key_validation_class)
        out.write("  %s %s PRIMARY KEY" % (alias, keytype))
        indexed_columns = []
        for col in cfdef.column_metadata:
            colname = self.cql_protect_name(col.name)
            out.write(",\n  %s %s" % (colname, cql_typename(col.validation_class)))
            if col.index_name is not None:
                indexed_columns.append(col)
        cf_opts = []
        for (option, thriftname) in cqlruleset.columnfamily_options:
            optval = getattr(cfdef, thriftname or option, None)
            if optval is None:
                continue
            if option in ('comparator', 'default_validation'):
                # type options are emitted as unquoted CQL type names
                optval = cql_typename(optval)
            else:
                # drop the standard package prefix from well-known class names
                if option == 'row_cache_provider':
                    optval = trim_if_present(optval, 'org.apache.cassandra.cache.')
                elif option == 'compaction_strategy_class':
                    optval = trim_if_present(optval, 'org.apache.cassandra.db.compaction.')
                optval = self.cql_protect_value(optval)
            cf_opts.append((option, optval))
        for option, thriftname, _ in cqlruleset.columnfamily_map_options:
            optmap = getattr(cfdef, thriftname or option, {})
            for k, v in optmap.items():
                if option == 'compression_parameters' and k == 'sstable_compression':
                    v = trim_if_present(v, 'org.apache.cassandra.io.compress.')
                cf_opts.append(('%s:%s' % (option, k), self.cql_protect_value(v)))
        out.write('\n)')
        if cf_opts:
            # first option is introduced with WITH, the rest with AND
            joiner = 'WITH'
            for optname, optval in cf_opts:
                out.write(" %s\n  %s=%s" % (joiner, optname, optval))
                joiner = 'AND'
        out.write(";\n")

        for col in indexed_columns:
            out.write('\n')
            # guess CQL can't represent index_type or index_options
            out.write('CREATE INDEX %s ON %s (%s);\n'
                         % (col.index_name, cfname, self.cql_protect_name(col.name)))

    def print_recreate_columnfamily_from_layout(self, layout, out):
        """
        Write to `out` a CQL 3 CREATE TABLE statement reconstructed from a
        table layout, followed by a CREATE INDEX statement for each indexed
        non-key column.

        layout.columns[0] is written first. It carries an inline PRIMARY KEY
        when the key has a single component. Otherwise an explicit PRIMARY
        KEY clause lists the partition key components (parenthesized when
        there are several) followed by the column aliases.

        Scalar options come from cqlruleset.columnfamily_layout_options. Map
        options come from columnfamily_layout_map_options and are written as
        CQL 3 maps (empty maps omitted), or as "option:key=value" pairs in
        CQL 2 / CQL 3 beta mode. The layout object itself is not modified.
        """
        cfname = self.cql_protect_name(layout.name)
        out.write("CREATE TABLE %s (\n" % cfname)
        keycol = layout.columns[0]
        out.write("  %s %s" % (self.cql_protect_name(keycol.name),
                               keycol.cqltype.cql_parameterized_type()))
        if len(layout.primary_key_components) == 1:
            out.write(" PRIMARY KEY")

        indexed_columns = []
        for col in layout.columns[1:]:
            colname = self.cql_protect_name(col.name)
            out.write(",\n  %s %s" % (colname, col.cqltype.cql_parameterized_type()))
            if col.index_name is not None:
                indexed_columns.append(col)

        if len(layout.primary_key_components) > 1:
            out.write(",\n  PRIMARY KEY (")
            partkeynames = map(self.cql_protect_name, layout.partition_key_components)
            if len(partkeynames) > 1:
                partkey = "(%s)" % ', '.join(partkeynames)
            else:
                partkey = partkeynames[0]
            pk_parts = [partkey] + map(self.cql_protect_name, layout.column_aliases)
            out.write(', '.join(pk_parts) + ')')

        out.write("\n)")
        joiner = 'WITH'

        if layout.compact_storage:
            out.write(' WITH COMPACT STORAGE')
            joiner = 'AND'

        # TODO: this should display CLUSTERING ORDER BY information too.
        # work out how to determine that from a layout.

        cf_opts = []
        for cql3option, layoutoption in cqlruleset.columnfamily_layout_options:
            if layoutoption is None:
                layoutoption = cql3option
            optval = getattr(layout, layoutoption, None)
            if optval is None:
                continue
            elif layoutoption == 'compaction_strategy_class':
                optval = trim_if_present(optval, 'org.apache.cassandra.db.compaction.')
            cf_opts.append((cql3option, self.cql_protect_value(optval)))
        for cql3option, layoutoption, _ in cqlruleset.columnfamily_layout_map_options:
            if layoutoption is None:
                layoutoption = cql3option
            # work on a copy so that trimming the compression class below
            # does not rewrite the layout's own map; a None map counts as empty
            optmap = dict(getattr(layout, layoutoption, None) or {})
            if layoutoption == 'compression_parameters':
                compclass = optmap.get('sstable_compression')
                if compclass is not None:
                    optmap['sstable_compression'] = \
                            trim_if_present(compclass, 'org.apache.cassandra.io.compress.')
            if self.cqlver_atleast(3) and not self.is_cql3_beta():
                cf_opts.append((cql3option, optmap))
            else:
                for k, v in optmap.items():
                    cf_opts.append(('%s:%s' % (cql3option, k.encode('ascii')),
                                    self.cql_protect_value(v)))
        for optname, optval in cf_opts:
            if isinstance(optval, dict):
                optval = '{%s}' % ', '.join(['%s: %s' % (self.cql_protect_value(k),
                                                         self.cql_protect_value(v))
                                             for (k, v) in optval.items()])
                if optval == '{}':
                    continue
            out.write(" %s\n  %s=%s" % (joiner, optname, optval))
            joiner = 'AND'
        out.write(";\n")

        for col in indexed_columns:
            out.write('\n')
            # guess CQL can't represent index_type or index_options
            out.write('CREATE INDEX %s ON %s (%s);\n'
                         % (col.index_name, cfname, self.cql_protect_name(col.name)))

    def describe_keyspace(self, ksname):
        # Print the CREATE statements for one keyspace, framed by blank lines.
        print
        target = self.get_keyspace(ksname)
        self.print_recreate_keyspace(target, sys.stdout)
        print

    def describe_columnfamily(self, ksname, cfname):
        # Print the CREATE statements for one table, framed by blank lines;
        # a missing keyspace means the current one.
        target_ks = self.current_keyspace if ksname is None else ksname
        print
        self.print_recreate_columnfamily(target_ks, cfname, sys.stdout)
        print

    def describe_columnfamilies(self, ksname):
        print
        if ksname is None:
            for k in self.get_keyspaces():
                print 'Keyspace %s' % (k.name,)
                print '---------%s' % ('-' * len(k.name))
                cmd.Cmd.columnize(self, self.get_columnfamily_names(k.name))
                print
        else:
            cmd.Cmd.columnize(self, self.get_columnfamily_names(ksname))
            print

    def describe_cluster(self):
        print '\nCluster: %s' % self.get_cluster_name()
        p = trim_if_present(self.get_partitioner(), 'org.apache.cassandra.dht.')
        print 'Partitioner: %s' % p
        snitch = trim_if_present(self.get_snitch(), 'org.apache.cassandra.locator.')
        print 'Snitch: %s\n' % snitch
        if self.current_keyspace is not None \
        and self.current_keyspace not in SYSTEM_KEYSPACES:
            print "Range ownership:"
            ring = self.get_ring()
            for entry in ring:
                print ' %39s  [%s]' % (entry.start_token, ', '.join(entry.endpoints))
            print

    def describe_schema(self):
        # Print the CREATE statements for every keyspace (and its tables),
        # each followed by a blank line.
        print
        for ksdef in self.get_keyspaces():
            self.print_recreate_keyspace(ksdef, sys.stdout)
            print

    def do_describe(self, parsed):
        """
        DESCRIBE [cqlsh only]

        (DESC may be used as a shorthand.)

          Outputs information about the connected Cassandra cluster, or about
          the data stored on it. Use in one of the following ways:

        DESCRIBE KEYSPACE [<keyspacename>]

          Output CQL commands that could be used to recreate the given
          keyspace, and the tables in it. In some cases, as the CQL interface
          matures, there will be some metadata about a keyspace that is not
          representable with CQL. That metadata will not be shown.

          The '<keyspacename>' argument may be omitted when using a non-system
          keyspace; in that case, the current keyspace will be described.

        DESCRIBE TABLES

          Output the names of all tables in the current keyspace, or in all
          keyspaces if there is no current keyspace.

        DESCRIBE TABLE <tablename>

          Output CQL commands that could be used to recreate the given table.
          In some cases, as above, there may be table metadata which is not
          representable and which will not be shown.

        DESCRIBE CLUSTER

          Output information about the connected Cassandra cluster, such as the
          cluster name, and the partitioner and snitch in use. When you are
          connected to a non-system keyspace, also shows endpoint-range
          ownership information for the Cassandra ring.

        DESCRIBE SCHEMA

          Output CQL commands that could be used to recreate the entire schema.
          Works as though "DESCRIBE KEYSPACE k" was invoked for each keyspace
          k.
        """
        # the second matched token names what is being described
        what = parsed.matched[1][1].lower()
        if what in ('columnfamilies', 'tables'):
            self.describe_columnfamilies(self.current_keyspace)
        elif what in ('columnfamily', 'table'):
            self.describe_columnfamily(
                self.cql_unprotect_name(parsed.get_binding('ksname', None)),
                self.cql_unprotect_name(parsed.get_binding('cfname')))
        elif what == 'keyspace':
            # fall back to the current keyspace when none was named
            ksname = self.cql_unprotect_name(parsed.get_binding('ksname', '')) \
                     or self.current_keyspace
            if ksname is None:
                self.printerr('Not in any keyspace.')
                return
            self.describe_keyspace(ksname)
        elif what == 'cluster':
            self.describe_cluster()
        elif what == 'schema':
            self.describe_schema()
    do_desc = do_describe

    def do_copy(self, parsed):
        r"""
        COPY [cqlsh only]

          COPY x FROM: Imports CSV data into a Cassandra table
          COPY x TO: Exports data from a Cassandra table in CSV format.

        COPY <table_name> [ ( column [, ...] ) ]
             FROM ( '<filename>' | STDIN )
             [ WITH <option>='value' [AND ...] ];

        COPY <table_name> [ ( column [, ...] ) ]
             TO ( '<filename>' | STDOUT )
             [ WITH <option>='value' [AND ...] ];

        Available options and defaults:

          DELIMITER=','    - character that appears between records
          QUOTE='"'        - quoting character to be used to quote fields
          ESCAPE='\'       - character to appear before the QUOTE char when quoted
          HEADER=false     - whether to ignore the first line
          ENCODING='utf8'  - encoding for CSV output (COPY TO only)
          NULL=''          - string that represents a null value (COPY TO only)

        When entering CSV data on STDIN, you can use the sequence "\."
        on a line by itself to end the data input.
        """
        ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
        if ks is None:
            ks = self.current_keyspace
            if ks is None:
                raise NoKeyspaceError("Not in any keyspace.")
        cf = self.cql_unprotect_name(parsed.get_binding('cfname'))
        columns = parsed.get_binding('colnames', None)
        if columns is not None:
            columns = map(self.cql_unprotect_name, columns)
        fname = parsed.get_binding('fname', None)
        if fname is not None:
            fname = os.path.expanduser(self.cql_unprotect_value(fname))
        copyoptnames = map(str.lower, parsed.get_binding('optnames', ()))
        copyoptvals = map(self.cql_unprotect_value, parsed.get_binding('optvals', ()))
        opts = dict(zip(copyoptnames, copyoptvals))

        timestart = time.time()

        direction = parsed.get_binding('dir').upper()
        if direction == 'FROM':
            rows = self.perform_csv_import(ks, cf, columns, fname, opts)
            verb = 'imported'
        elif direction == 'TO':
            rows = self.perform_csv_export(ks, cf, columns, fname, opts)
            verb = 'exported'
        else:
            raise SyntaxError("Unknown direction %s" % direction)

        timeend = time.time()
        print "%d rows %s in %s." % (rows, verb, describe_interval(timeend - timestart))

    def perform_csv_import(self, ks, cf, columns, fname, opts):
        dialect_options = self.csv_dialect_defaults.copy()
        if 'quote' in opts:
            dialect_options['quotechar'] = opts.pop('quote')
        if 'escape' in opts:
            dialect_options['escapechar'] = opts.pop('escape')
        if 'delimiter' in opts:
            dialect_options['delimiter'] = opts.pop('delimiter')
        header = bool(opts.pop('header', '').lower() == 'true')
        if dialect_options['quotechar'] == dialect_options['escapechar']:
            dialect_options['doublequote'] = True
            del dialect_options['escapechar']
        if opts:
            self.printerr('Unrecognized COPY FROM options: %s'
                          % ', '.join(opts.keys()))
            return 0

        if fname is None:
            do_close = False
            print "[Use \. on a line by itself to end input]"
            linesource = self.use_stdin_reader(prompt='[copy] ', until=r'\.')
        else:
            do_close = True
            try:
                linesource = open(fname, 'rb')
            except IOError, e:
                self.printerr("Can't open %r for reading: %s" % (fname, e))
                return 0
        try:
            if header:
                linesource.next()
            numcol, prepq = self.prep_import_insert(ks, cf, columns)
            rownum = -1
            reader = csv.reader(linesource, **dialect_options)
            for rownum, row in enumerate(reader):
                if len(row) != numcol:
                    self.printerr("Record #%d (line %d) has the wrong number of fields "
                                  "(%d instead of %d)."
                                  % (rownum, reader.line_num, len(row), numcol))
                    return rownum
                if not self.do_import_insert(prepq, row):
                    self.printerr("Aborting import at record #%d (line %d). "
                                  "Previously-inserted values still present."
                                  % (rownum, reader.line_num))
                    return rownum
        finally:
            if do_close:
                linesource.close()
            elif self.tty:
                print
        return rownum + 1

    def prep_import_insert(self, ks, cf, columns):
        """
        Build the INSERT template used for CSV import of ks.cf.

        Returns (column_count, template). The template's single %s is later
        filled with a record's comma-separated CQL literals. When `columns`
        is None, all known columns of the table are used.
        """
        # A prepared statement would need every text field converted to a
        # native value first (mirroring each Cassandra type's parsing), so
        # values are sent as intact CQL string literals for the server to
        # interpret instead.
        if columns is None:
            columns = self.get_column_names(ks, cf)
        quoted_ks = self.cql_protect_name(ks)
        quoted_cf = self.cql_protect_name(cf)
        column_list = ', '.join([self.cql_protect_name(c) for c in columns])
        template = 'INSERT INTO %s.%s (%s) VALUES (%%s)' % (quoted_ks, quoted_cf, column_list)
        return len(columns), template

    def do_import_insert(self, prepq, rowvalues):
        valstring = ', '.join(map(self.cql_protect_value, rowvalues))
        cql = prepq % valstring
        if self.debug:
            print "Import using CQL: %s" % cql
        return self.perform_statement_untraced(cql)

    def perform_csv_export(self, ks, cf, columns, fname, opts):
        dialect_options = self.csv_dialect_defaults.copy()
        if 'quote' in opts:
            dialect_options['quotechar'] = opts.pop('quote')
        if 'escape' in opts:
            dialect_options['escapechar'] = opts.pop('escape')
        if 'delimiter' in opts:
            dialect_options['delimiter'] = opts.pop('delimiter')
        encoding = opts.pop('encoding', 'utf8')
        nullval = opts.pop('null', '')
        header = bool(opts.pop('header', '').lower() == 'true')
        if dialect_options['quotechar'] == dialect_options['escapechar']:
            dialect_options['doublequote'] = True
            del dialect_options['escapechar']

        if opts:
            self.printerr('Unrecognized COPY TO options: %s'
                          % ', '.join(opts.keys()))
            return 0

        if fname is None:
            do_close = False
            csvdest = sys.stdout
        else:
            do_close = True
            try:
                csvdest = open(fname, 'wb')
            except IOError, e:
                self.printerr("Can't open %r for writing: %s" % (fname, e))
                return 0
        try:
            self.prep_export_dump(ks, cf, columns)
            writer = csv.writer(csvdest, **dialect_options)
            if header:
                writer.writerow([d[0] for d in self.cursor.description])
            rows = 0
            while True:
                row = self.cursor.fetchone()
                if row is None:
                    break
                fmt = lambda v, t: \
                    format_value(v, t, output_encoding=encoding, nullval=nullval,
                                 time_format=self.display_time_format,
                                 float_precision=self.display_float_precision).strval
                writer.writerow(map(fmt, row, self.cursor.column_types))
                rows += 1
        finally:
            if do_close:
                csvdest.close()
        return rows

    def prep_export_dump(self, ks, cf, columns):
        """
        Execute the SELECT that feeds a CSV export of ks.cf on self.cursor.

        columns: names to select, or None to select every column reported
        by get_column_names(). All names are protected with
        cql_protect_name() before being placed in the query.
        """
        if columns is None:
            columns = self.get_column_names(ks, cf)
        protect = self.cql_protect_name
        selected = ', '.join([protect(c) for c in columns])
        table = '%s.%s' % (protect(ks), protect(cf))
        # this limit is pretty awful. would be better to use row-key-paging, so
        # that the dump could be pretty easily aborted if necessary, but that
        # can be kind of tricky with cql3. Punt for now, until the real cursor
        # API is added in CASSANDRA-4415.
        self.cursor.execute('SELECT %s FROM %s LIMIT 99999999' % (selected, table))

    def do_show(self, parsed):
        """
        SHOW [cqlsh only]

          Displays information about the current cqlsh session. Can be called in
          the following ways:

        SHOW VERSION

          Shows the version and build of the connected Cassandra instance, as
          well as the versions of the CQL spec and the Thrift protocol that
          the connected Cassandra instance understands.

        SHOW HOST

          Shows where cqlsh is currently connected.

        SHOW ASSUMPTIONS

          Outputs the current list of type assumptions as specified by the
          user. See the help for the ASSUME command for more information.
        """
        # (The docstring above is the HELP text for this command.)
        what = parsed.get_binding('what').lower()
        if what == 'version':
            # Version info is refreshed from the connection before display.
            self.get_connection_versions()
            self.show_version()
            return
        handler_name = {'host': 'show_host',
                        'assumptions': 'show_assumptions'}.get(what)
        if handler_name is None:
            self.printerr('Wait, how do I show %r?' % (what,))
        else:
            getattr(self, handler_name)()

    def do_assume(self, parsed):
        """
        ASSUME [cqlsh only]

          Instruct cqlsh to consider certain column names or values to be of a
          specified type, even if that type information is not specified in
          the table's metadata. Data will be deserialized according to the
          given type, and displayed appropriately when retrieved.

          Use thus:

        ASSUME [<keyspace>.]<tablename> NAMES ARE <type>;

          Treat all column names in the given table as being of the
          given type.

        ASSUME [<keyspace>.]<tablename> VALUES ARE <type>;

          Treat all column values in the given table as being of the
          given type, unless there is more information about the specific
          column being deserialized. That is, a column-specific ASSUME will
          take precedence here, as will column-specific metadata in the
          table's definition.

        ASSUME [<keyspace>.]<tablename>(<colname>) VALUES ARE <type>;

          Treat all values in the given column in the given table as
          being of the specified type. This overrides any other information
          about the type of a value.

        Assign multiple overrides at once for the same table by
        separating with commas:

          ASSUME ks.table NAMES ARE uuid, VALUES ARE int, (col) VALUES ARE ascii

        See HELP TYPES for information on the supported data storage types.
        """
        # (The docstring above is the HELP text for this command.)
        ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
        cf = self.cql_unprotect_name(parsed.get_binding('cfname'))
        colname = self.cql_unprotect_name(parsed.get_binding('colname', None))

        overridekinds = ('names', 'values', 'colvalues')
        # kind -> requested type name (None when that clause was not given)
        requested = dict((kind, self.cql_unprotect_value(parsed.get_binding(kind, None)))
                         for kind in overridekinds)

        # An unqualified table name refers to the current keyspace.
        if ks is None:
            ks = self.current_keyspace
            if ks is None:
                self.printerr('Error: not in any keyspace.')
                return

        for kind in overridekinds:
            typename = requested[kind]
            if typename is None:
                continue
            try:
                validator_class = lookup_cqltype(typename).cass_parameterized_type()
            except KeyError:
                self.printerr('Error: validator type %s not found.' % typename)
                continue
            self.add_assumption(ks, cf, colname, kind, validator_class)

    def do_source(self, parsed):
        """
        SOURCE [cqlsh only]

        Executes a file containing CQL statements. Gives the output for each
        statement in turn, if any, or any errors that occur along the way.

        Errors do NOT abort execution of the CQL source file.

        Usage:

          SOURCE '<file>';

        That is, the path to the file to be executed must be given inside a
        string literal. The path is interpreted relative to the current working
        directory. The tilde shorthand notation ('~/mydir') is supported for
        referring to $HOME.

        See also the --file option to cqlsh.
        """
        fname = parsed.get_binding('fname')
        fname = os.path.expanduser(self.cql_unprotect_value(fname))
        try:
            f = open(fname, 'r')
        except IOError, e:
            self.printerr('Could not open %r: %s' % (fname, e))
            return
        subshell = Shell(self.hostname, self.port, self.transport_factory,
                         color=self.color, encoding=self.encoding, stdin=f,
                         tty=False, use_conn=self.conn, cqlver=self.cql_version,
                         display_time_format=self.display_time_format,
                         display_float_precision=self.display_float_precision)
        subshell.cmdloop()
        f.close()

    def do_capture(self, parsed):
        """
        CAPTURE [cqlsh only]

        Begins capturing command output and appending it to a specified file.
        Output will not be shown at the console while it is captured.

        Usage:

          CAPTURE '<file>';
          CAPTURE OFF;
          CAPTURE;

        That is, the path to the file to be appended to must be given inside a
        string literal. The path is interpreted relative to the current working
        directory. The tilde shorthand notation ('~/mydir') is supported for
        referring to $HOME.

        Only query result output is captured. Errors and output from cqlsh-only
        commands will still be shown in the cqlsh session.

        To stop capturing output and show it in the cqlsh session again, use
        CAPTURE OFF.

        To inspect the current capture configuration, use CAPTURE with no
        arguments.
        """
        fname = parsed.get_binding('fname')
        if fname is None:
            if self.shunted_query_out is not None:
                print "Currently capturing query output to %r." % (self.query_out.name,)
            else:
                print "Currently not capturing query output."
            return

        if fname.upper() == 'OFF':
            if self.shunted_query_out is None:
                self.printerr('Not currently capturing output.')
                return
            self.query_out.close()
            self.query_out = self.shunted_query_out
            self.color = self.shunted_color
            self.shunted_query_out = None
            del self.shunted_color
            return

        if self.shunted_query_out is not None:
            self.printerr('Already capturing output to %s. Use CAPTURE OFF'
                          ' to disable.' % (self.query_out.name,))
            return

        fname = os.path.expanduser(self.cql_unprotect_value(fname))
        try:
            f = open(fname, 'a')
        except IOError, e:
            self.printerr('Could not open %r for append: %s' % (fname, e))
            return
        self.shunted_query_out = self.query_out
        self.shunted_color = self.color
        self.query_out = f
        self.color = False
        print 'Now capturing query output to %r.' % (fname,)

    def do_tracing(self, parsed):
        """
        TRACING [cqlsh]

          Enables or disables request tracing.

        TRACING ON

          Enables tracing for all further requests.

        TRACING OFF

          Disables tracing.

        TRACING

          TRACING with no arguments shows the current tracing status.
        """
        if not self.cqlver_atleast(3):
            self.printerr('Tracing requires CQL version 3.0.0 or higher.')
            return

        switch = parsed.get_binding('switch')
        if switch is None:
            if self.tracing_enabled:
                print "Tracing is currently enabled. Use TRACING OFF to disable"
            else:
                print "Tracing is currently disabled. Use TRACING ON to enable."
            return

        if switch.upper() == 'ON':
            if self.tracing_enabled:
                self.printerr('Tracing is already enabled. '
                              'Use TRACING OFF to disable.')
                return
            self.tracing_enabled = True
            print 'Now tracing requests.'
            return

        if switch.upper() == 'OFF':
            if not self.tracing_enabled:
                self.printerr('Tracing is not enabled.')
                return
            self.tracing_enabled = False
            print 'Disabled tracing.'

    def do_consistency(self, parsed):
        """
        CONSISTENCY [cqlsh with CQL3 only]

           Overrides default consistency level (default level is ONE).

        CONSISTENCY <level>

           Sets consistency level for future requests.

           Valid consistency levels:

           ANY, ONE, TWO, THREE, QUORUM, ALL, LOCAL_QUORUM and EACH_QUORUM.

        CONSISTENCY

           CONSISTENCY with no arguments shows the current consistency level.
        """
        if not self.cqlver_atleast(3):
            self.printerr('CONSISTENCY requires CQL version 3.0.0 or higher.')
            return

        level = parsed.get_binding('level')
        if level is None:
            print 'Current consistency level is %s.' % (self.cursor.consistency_level,)
            return

        self.cursor.consistency_level = level.upper()
        print 'Consistency level set to %s.' % (level.upper(),)

    def do_exit(self, parsed=None):
        """
        EXIT/QUIT [cqlsh only]

        Exits cqlsh.
        """
        # (The docstring above is the HELP text for this command.)
        # Setting the stop flag is what ends the session; presumably it is
        # checked by the command loop elsewhere in Shell -- confirm.
        self.stop = True
    # QUIT is an alias sharing EXIT's docstring; get_help_topics() removes
    # 'quit' so the command is listed only once in HELP.
    do_quit = do_exit

    def do_debug(self, parsed):
        # Drops into the pdb debugger at this point.
        # Deliberately has no docstring: get_help_topics() lists only do_*
        # methods that have one, so this command stays out of HELP.
        import pdb
        pdb.set_trace()

    def get_help_topics(self):
        """
        Return the names of documented shell commands, in dir() order.

        A command is any do_<name> attribute whose __doc__ is non-empty; the
        'do_' prefix is stripped. 'quit' is removed because it is an alias
        of 'exit' and would otherwise be listed twice.
        """
        documented = []
        for attrname in dir(self):
            if not attrname.startswith('do_'):
                continue
            if getattr(self, attrname, None).__doc__:
                documented.append(attrname[len('do_'):])
        documented.remove('quit')
        return documented

    def columnize(self, slist, *a, **kw):
        """Upper-case and sort the words, then lay them out via cmd.Cmd.columnize."""
        shouted = [word.upper() for word in slist]
        shouted.sort()
        return cmd.Cmd.columnize(self, shouted, *a, **kw)

    def do_help(self, parsed):
        """
        HELP [cqlsh only]

        Gives information about cqlsh commands. To see available topics,
        enter "HELP" without any arguments. To see help on a topic,
        use "HELP <topic>".
        """
        # (The docstring above is the HELP text for this command.)
        topics = parsed.get_binding('topic', ())
        if topics:
            for topic in topics:
                key = topic.lower()
                if key in self.get_help_topics():
                    # Shell commands are documented by their do_* docstrings.
                    command = getattr(self, 'do_' + key)
                    self.stdout.write(command.__doc__ + "\n")
                elif key in cqldocs.get_help_topics():
                    cqldocs.print_help_topic(topic)
                else:
                    self.printerr("*** No help on %s" % (topic,))
            return

        # No topic given: list shell commands, then CQL help topics.
        shell_topics = [name.upper() for name in self.get_help_topics()]
        self.print_topics("\nDocumented shell commands:", shell_topics, 15, 80)
        cql_topics = [name.upper() for name in cqldocs.get_help_topics()]
        self.print_topics("CQL help topics:", cql_topics, 15, 80)

    def applycolor(self, text, color=None):
        """
        Wrap text in the given ANSI color code followed by ANSI_RESET.

        The text is returned unchanged when no color is requested or when
        this shell has color output disabled.
        """
        if color and self.color:
            return color + text + ANSI_RESET
        return text

    def writeresult(self, text, color=None, newline=True, out=None):
        """
        Write str(text), optionally colored, to `out`.

        `out` defaults to self.query_out, the stream CAPTURE redirects. A
        trailing newline is appended unless newline is False.
        """
        if out is None:
            out = self.query_out
        line = self.applycolor(str(text), color)
        if newline:
            line += '\n'
        out.write(line)

    def flush_output(self):
        """Flush self.query_out, the stream query results are written to
        (the capture file while CAPTURE is active)."""
        self.query_out.flush()

    def printerr(self, text, color=RED, newline=True, shownum=None):
        """
        Report an error message on sys.stderr, colored RED by default.

        When shownum is true, or is None and self.show_line_nums is set, the
        message is prefixed with '<input name>:<line number>:'.
        """
        if shownum is None:
            with_location = self.show_line_nums
        else:
            with_location = shownum
        message = text
        if with_location:
            message = '%s:%d:%s' % (self.stdin.name, self.lineno, text)
        self.writeresult(message, color, newline=newline, out=sys.stderr)

    def add_assumption(self, ksname, cfname, colname, valtype, valclass):
        """
        Record a type override for table (ksname, cfname).

        valtype selects what is overridden: 'names' (default name type),
        'values' (default value type) or 'colvalues' (value type of the
        single column colname). Any other valtype records nothing beyond
        creating the table's override entry.
        """
        key = (ksname, cfname)
        if key not in self.schema_overrides:
            self.schema_overrides[key] = FakeCqlMetadata()
        overrides = self.schema_overrides[key]

        if valtype == 'colvalues':
            overrides.value_types[colname] = valclass
            return
        default_attr = {'names': 'default_name_type',
                        'values': 'default_value_type'}.get(valtype)
        if default_attr is not None:
            setattr(overrides, default_attr, valclass)

class FakeCqlMetadata:
    """
    Container for user-supplied (ASSUME) type overrides for one table.

    Filled in by Shell.add_assumption() and merged into a decoder's schema
    by OverrideableSchemaDecoder.apply_schema_overrides(). Defaults of None
    mean "no override".
    """
    def __init__(self):
        self.name_types, self.value_types = {}, {}
        self.default_name_type = self.default_value_type = None

class OverrideableSchemaDecoder(cql.decoders.SchemaDecoder):
    """
    cql SchemaDecoder whose schema can be patched with user type overrides
    (a FakeCqlMetadata-like object) right after construction.
    """
    def __init__(self, schema, overrides=None):
        cql.decoders.SchemaDecoder.__init__(self, schema)
        self.apply_schema_overrides(overrides)

    def apply_schema_overrides(self, overrides):
        """
        Merge `overrides` into self.schema; None means nothing to apply.

        Default name/value types are replaced only when the override is not
        None; per-column name/value type maps are merged key by key.
        """
        if overrides is not None:
            for attrname in ('default_name_type', 'default_value_type'):
                replacement = getattr(overrides, attrname)
                if replacement is not None:
                    setattr(self.schema, attrname, replacement)
            self.schema.name_types.update(overrides.name_types)
            self.schema.value_types.update(overrides.value_types)

class ErrorHandlingSchemaDecoder(OverrideableSchemaDecoder):
    """
    Decoder whose decode-error hooks return DecodeError objects wrapping the
    undecodable bytes. Presumably the base SchemaDecoder calls these hooks
    when decoding fails -- confirm against cql.decoders.
    """
    def name_decode_error(self, err, namebytes, expectedtype):
        failed = DecodeError(namebytes, err, expectedtype)
        return failed

    def value_decode_error(self, err, namebytes, valuebytes, expectedtype):
        # The column name is attached so the failing value can be attributed.
        failed = DecodeError(valuebytes, err, expectedtype, colname=namebytes)
        return failed

def option_with_default(cparser_getter, section, option, default=None):
    """
    Return cparser_getter(section, option), or `default` if the getter
    raises any ConfigParser.Error (e.g. missing section or option).

    cparser_getter is a bound ConfigParser accessor such as configs.get,
    configs.getint or configs.getboolean.
    """
    try:
        value = cparser_getter(section, option)
    except ConfigParser.Error:
        value = default
    return value

def raw_option_with_default(configs, section, option, default=None):
    """
    Same (almost) as option_with_default() but won't do any string interpolation.
    Useful for config values that include '%' symbol, e.g. time format string.

    Reads configs.get(section, option, raw=True); returns `default` when
    that raises any ConfigParser.Error.
    """
    try:
        value = configs.get(section, option, raw=True)
    except ConfigParser.Error:
        return default
    return value

def should_use_color():
    """
    Decide whether cqlsh should emit ANSI color codes by default.

    Returns False when stdout is not a terminal, when $TERM is unset or
    "dumb", or when `tput colors` reports fewer than 8 colors. If tput
    cannot be run, or its output is not an integer, falls back to True.
    """
    if not sys.stdout.isatty():
        return False
    if os.environ.get('TERM', 'dumb') == 'dumb':
        return False
    try:
        import subprocess
        # stderr is captured too, so tput complaints (e.g. an unknown
        # terminal type) do not leak onto the user's console.
        p = subprocess.Popen(['tput', 'colors'], stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
        stdout, _ = p.communicate()
        if int(stdout.strip()) < 8:
            return False
    except (OSError, ImportError, ValueError):
        # oh well, we tried. at least we know there's a $TERM and it's
        # not "dumb". ValueError covers tput printing nothing or something
        # non-numeric, which int() rejects; before, that crashed startup.
        pass
    return True

def load_factory(name):
    """
    Attempts to load a transport factory function given its fully qualified
    name, e.g. "cqlshlib.tfactory.regular_transport_factory"

    Imports the module part of the dotted name and walks the remaining
    components with getattr(). Exits via sys.exit() with an error message
    if the name is not fully qualified or cannot be resolved.
    """
    parts = name.split('.')
    if len(parts) < 2:
        # There is no module part; __import__('') would raise ValueError.
        sys.exit("Can't locate transport factory function %s" % name)
    module = ".".join(parts[:-1])
    try:
        # __import__('a.b') returns the top-level package 'a', so every
        # component after the first is resolved with getattr().
        t = __import__(module)
        for part in parts[1:]:
            t = getattr(t, part)
        return t
    except (ImportError, AttributeError, ValueError):
        # ValueError: an empty module component, e.g. "pkg..func".
        sys.exit("Can't locate transport factory function %s" % name)

def read_options(cmdlineargs, environment):
    """
    Build the effective cqlsh settings from the config file, the
    environment and the command line.

    Settings are read from CONFIG_FILE (falling back to the DEFAULT_*
    constants) and then overridden by command-line flags parsed with the
    module-level `parser`. The host and port are further overridden by
    CQLSH_HOST / CQLSH_PORT in `environment`, then by positional arguments.

    Returns (options, hostname, port), with port converted to int. Calls
    parser.error() (which exits) when the port is not an integer.
    """
    configs = ConfigParser.SafeConfigParser()
    configs.read(CONFIG_FILE)

    # Seed an optparse Values object with config-file settings; parse_args()
    # below fills it in place, so command-line flags override these.
    optvalues = optparse.Values()
    optvalues.username = option_with_default(configs.get, 'authentication', 'username')
    optvalues.password = option_with_default(configs.get, 'authentication', 'password')
    optvalues.keyspace = option_with_default(configs.get, 'authentication', 'keyspace')
    optvalues.transport_factory = option_with_default(configs.get, 'connection', 'factory',
                                                      DEFAULT_TRANSPORT_FACTORY)
    optvalues.completekey = option_with_default(configs.get, 'ui', 'completekey',
                                                DEFAULT_COMPLETEKEY)
    # None here means "not configured": color is auto-detected further down.
    # NOTE(review): getboolean/getint raise ValueError (not ConfigParser.Error)
    # on malformed values, which option_with_default does not catch.
    optvalues.color = option_with_default(configs.getboolean, 'ui', 'color')
    # Read raw: time format strings contain '%' and must not be interpolated.
    optvalues.time_format = raw_option_with_default(configs, 'ui', 'time_format',
                                                    DEFAULT_TIME_FORMAT)
    optvalues.float_precision = option_with_default(configs.getint, 'ui', 'float_precision',
                                                    DEFAULT_FLOAT_PRECISION)
    optvalues.debug = False
    optvalues.file = None
    optvalues.tty = sys.stdin.isatty()
    optvalues.cqlversion = option_with_default(configs.get, 'cql', 'version', DEFAULT_CQLVER)

    # `options` is the same Values object as `optvalues`, now updated with
    # any command-line flags; `arguments` holds the positional arguments.
    (options, arguments) = parser.parse_args(cmdlineargs, values=optvalues)

    hostname = option_with_default(configs.get, 'connection', 'hostname', DEFAULT_HOST)
    port = option_with_default(configs.get, 'connection', 'port', DEFAULT_PORT)

    hostname = environment.get('CQLSH_HOST', hostname)
    port = environment.get('CQLSH_PORT', port)

    # Positional arguments win: [host [port]].
    if len(arguments) > 0:
        hostname = arguments[0]
    if len(arguments) > 1:
        port = arguments[1]

    # Input read from a file is never treated as an interactive terminal.
    if options.file is not None:
        options.tty = False

    # Replace the dotted name with the callable itself (exits if unresolvable).
    options.transport_factory = load_factory(options.transport_factory)

    # An explicit color setting (config file or flag) is kept; otherwise
    # color is off when reading from a file and auto-detected for a terminal.
    if optvalues.color in (True, False):
        options.color = optvalues.color
    else:
        if options.file is not None:
            options.color = False
        else:
            options.color = should_use_color()

    # Pick the CQL2 or CQL3 language-handling module by major version.
    options.cqlversion, cqlvertup = full_cql_version(options.cqlversion)
    if cqlvertup[0] < 3:
        options.cqlmodule = cqlhandling
    else:
        options.cqlmodule = cql3handling

    try:
        port = int(port)
    except ValueError:
        parser.error('%r is not a valid port number.' % port)

    return options, hostname, port

def setup_cqlruleset(cqlmodule):
    """
    Install cqlmodule.CqlRuleSet as the global `cqlruleset` and extend it
    with cqlsh's own syntax rules, completers and newline-terminated
    commands.
    """
    global cqlruleset
    ruleset = cqlmodule.CqlRuleSet
    cqlruleset = ruleset
    ruleset.append_rules(cqlsh_extra_syntax_rules)
    for rule, term, completer in cqlsh_syntax_completers:
        # completer_for() returns a decorator that registers the function.
        register = ruleset.completer_for(rule, term)
        register(completer)
    ruleset.commands_end_with_newline.update(my_commands_ending_with_newline)

def setup_cqldocs(cqlmodule):
    """Publish cqlmodule.cqldocs as the module-global `cqldocs` used by HELP."""
    global cqldocs
    docs = cqlmodule.cqldocs
    cqldocs = docs

def init_history():
    """
    Load saved readline history from HISTORY and add '.' to the completer
    delimiters.

    Does nothing when readline is unavailable. A missing or unreadable
    history file (IOError) is silently ignored.
    """
    if readline is None:
        return
    try:
        readline.read_history_file(HISTORY)
    except IOError:
        pass
    # NOTE(review): this function used to call delims.replace("'", "") and
    # discard the result (str is immutable), so the single quote has always
    # stayed a completer delimiter. That dead statement is dropped here,
    # which keeps the existing behaviour. If removing the quote was really
    # intended, assign the result -- but first confirm the completer copes
    # with quote characters in the completion text.
    # '.' is added, presumably so keyspace.table names complete one
    # component at a time.
    delims = readline.get_completer_delims() + '.'
    readline.set_completer_delims(delims)

def save_history():
    """Write readline history to HISTORY; no-op without readline, IOError ignored."""
    if readline is None:
        return
    try:
        readline.write_history_file(HISTORY)
    except IOError:
        pass

def main(options, hostname, port):
    setup_cqlruleset(options.cqlmodule)
    setup_cqldocs(options.cqlmodule)
    init_history()

    if options.file is None:
        stdin = None
    else:
        try:
            stdin = open(options.file, 'r')
        except IOError, e:
            sys.exit("Can't open %r: %s" % (options.file, e))

    if options.debug:
        import thrift
        sys.stderr.write("Using CQL driver: %s\n" % (cql,))
        sys.stderr.write("Using thrift lib: %s\n" % (thrift,))

    try:
        shell = Shell(hostname,
                      port,
                      options.transport_factory,
                      color=options.color,
                      username=options.username,
                      password=options.password,
                      stdin=stdin,
                      tty=options.tty,
                      completekey=options.completekey,
                      cqlver=options.cqlversion,
                      keyspace=options.keyspace,
                      display_time_format=options.time_format,
                      display_float_precision=options.float_precision)
    except KeyboardInterrupt:
        sys.exit('Connection aborted.')
    except CQL_ERRORS, e:
        sys.exit('Connection error: %s' % (e,))
    except VersionNotSupported, e:
        sys.exit('Unsupported CQL version: %s' % (e,))
    if options.debug:
        shell.debug = True

    shell.cmdloop()
    save_history()

if __name__ == '__main__':
    # Script entry point: read_options() returns (options, hostname, port),
    # which is unpacked directly into main().
    main(*read_options(sys.argv[1:], os.environ))

# vim: set ft=python et ts=4 sw=4 :
