import re
import json
import warnings
import contextlib

from pymysql.err import (
    Warning, Error, InterfaceError, DataError,
    DatabaseError, OperationalError, IntegrityError, InternalError,
    NotSupportedError, ProgrammingError)

from .log import logger
from .connection import FIELD_TYPE

# See the corresponding pattern in PyMySQL:
# https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18

#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only supports simple bulk insert.
#: You can use it to load large datasets.
#:
#: Capture groups:
#:   1. the statement head, ``INSERT ...``/``REPLACE ...`` up to and
#:      including ``VALUES``/``VALUE`` and the whitespace after it;
#:   2. exactly one parenthesised tuple of placeholders, each either a
#:      positional ``%s`` or a named ``%(name)s``;
#:   3. an optional trailing ``ON DUPLICATE ...`` clause (may be empty).
#: An optional trailing ``;`` and whitespace are accepted. Matching is
#: case-insensitive and ``.`` also matches newlines (re.DOTALL).
RE_INSERT_VALUES = re.compile(
    r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
    r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
    re.IGNORECASE | re.DOTALL)


class Cursor:
    """Cursor is used to interact with the database.

    After each query the rows of the current result are kept in
    ``self._rows``. :meth:`fetchone`, :meth:`fetchmany`, :meth:`fetchall`
    and :meth:`scroll` serve rows from there and return already-completed
    futures, so they can be awaited like the coroutine methods.
    """

    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet -
    #: packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    def __init__(self, connection, echo=False):
        """Do not create an instance of a Cursor yourself. Call
        connections.Connection.cursor().

        :param connection: connection the cursor runs queries on; its
            ``loop`` attribute is used to create futures.
        :param echo: if true, executed queries and their arguments are
            logged at INFO level.
        """
        self._connection = connection
        self._loop = self._connection.loop
        self._description = None
        self._rownumber = 0
        self._rowcount = -1
        self._arraysize = 1
        self._executed = None
        self._result = None
        self._rows = None
        self._lastrowid = None
        self._echo = echo

    @property
    def connection(self):
        """This read-only attribute returns a reference to the Connection
        object on which the cursor was created."""
        return self._connection

    @property
    def description(self):
        """This read-only attribute is a sequence of 7-item sequences.

        Each item describes one result column:

        0.  name: the name of the column returned.
        1.  type_code: the type of the column.
        2.  display_size: the actual length of the column in bytes.
        3.  internal_size: the size in bytes of the column associated to
            this column on the server.
        4.  precision: total number of significant digits in columns of
            type NUMERIC. None for other types.
        5.  scale: count of decimal digits in the fractional part in
            columns of type NUMERIC. None for other types.
        6.  null_ok: nullability of the column.

        The value is taken as-is from the connection's result object.
        This attribute will be None for operations that do not
        return rows or if the cursor has not had an operation invoked
        via the execute() method yet.
        """
        return self._description

    @property
    def rowcount(self):
        """Returns the number of rows that have been produced or affected.

        This read-only attribute specifies the number of rows that the
        last :meth:`execute` produced (for Data Query Language
        statements like SELECT) or affected (for Data Manipulation
        Language statements like UPDATE or INSERT).

        The attribute is -1 in case no .execute() has been performed
        on the cursor or the row count of the last operation if it
        can't be determined by the interface.
        """
        return self._rowcount

    @property
    def rownumber(self):
        """Row index.

        This read-only attribute provides the current 0-based index of the
        cursor in the result set or ``None`` if the index cannot be
        determined.
        """

        return self._rownumber

    @property
    def arraysize(self):
        """How many rows will be returned by fetchmany() call.

        This read/write attribute specifies the number of rows to
        fetch at a time with fetchmany(). It defaults to
        1 meaning to fetch a single row at a time.

        """
        return self._arraysize

    @arraysize.setter
    def arraysize(self, val):
        """How many rows will be returned by fetchmany() call.

        This read/write attribute specifies the number of rows to
        fetch at a time with fetchmany(). It defaults to
        1 meaning to fetch a single row at a time.

        """
        self._arraysize = val

    @property
    def lastrowid(self):
        """This read-only property returns the value generated for an
        AUTO_INCREMENT column by the previous INSERT or UPDATE statement
        or None when there is no such value available. For example,
        if you perform an INSERT into a table that contains an AUTO_INCREMENT
        column, lastrowid returns the AUTO_INCREMENT value for the new row.
        """
        return self._lastrowid

    @property
    def echo(self):
        """Return echo mode status."""
        return self._echo

    @property
    def closed(self):
        """The read-only property that returns ``True`` if the connection
        was detached from the current cursor.
        """
        return not self._connection

    async def close(self):
        """Close the cursor by draining any remaining result sets.

        The connection is detached from the cursor even if draining fails.
        Closing an already-closed cursor is a no-op.
        """
        conn = self._connection
        if conn is None:
            return
        try:
            while (await self.nextset()):
                pass
        finally:
            self._connection = None

    def _get_db(self):
        """Return the attached connection or raise ProgrammingError."""
        if not self._connection:
            raise ProgrammingError("Cursor closed")
        return self._connection

    def _check_executed(self):
        """Raise ProgrammingError if no statement has been executed yet."""
        if not self._executed:
            raise ProgrammingError("execute() first")

    def _conv_row(self, row):
        """Row conversion hook for subclasses and mixins.

        The base implementation returns *row* unchanged.
        """
        return row

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""

    async def nextset(self):
        """Advance to the next result set of the current query.

        :returns: ``True`` if another result set was loaded, otherwise
            ``None``. ``None`` also covers the case where this cursor's
            result is no longer the connection's current one.
        """
        conn = self._get_db()
        current_result = self._result
        if current_result is None or current_result is not conn._result:
            return
        if not current_result.has_next:
            return
        self._result = None
        self._clear_result()
        await conn.next_result()
        await self._do_get_result()
        return True

    def _escape_args(self, args, conn):
        """Escape query parameters with ``conn.escape``.

        Tuples and lists become a tuple of escaped values, and dicts
        become a dict with escaped values. Anything else is passed to
        ``conn.escape`` as a single value.
        """
        if isinstance(args, (tuple, list)):
            return tuple(conn.escape(arg) for arg in args)
        elif isinstance(args, dict):
            return {key: conn.escape(val) for (key, val) in args.items()}
        else:
            # Not a sequence or mapping: try escaping it as a single value
            # and let conn.escape raise if it cannot handle it.
            return conn.escape(args)

    def mogrify(self, query, args=None):
        """ Returns the exact string that is sent to the database by calling
        the execute() method. This method follows the extension to the DB
        API 2.0 followed by Psycopg.

        :param query: ``str`` sql statement
        :param args: ``tuple``, ``list`` or ``dict`` of arguments for the
            sql query
        """
        conn = self._get_db()
        if args is not None:
            query = query % self._escape_args(args, conn)
        return query

    async def execute(self, query, args=None):
        """Executes the given operation

        Executes the given operation substituting any markers with
        the given parameters.

        For example, getting all rows where id is 5:
          cursor.execute("SELECT * FROM t1 WHERE id = %s", (5,))

        :param query: ``str`` sql statement
        :param args: ``tuple``, ``list`` or ``dict`` of arguments for the
            sql query
        :returns: ``int``, number of rows that have been produced or affected
        """
        conn = self._get_db()

        # Drain any pending result sets from a previous statement first.
        while (await self.nextset()):
            pass

        if args is not None:
            query = query % self._escape_args(args, conn)

        await self._query(query)
        self._executed = query
        if self._echo:
            logger.info(query)
            logger.info("%r", args)
        return self._rowcount

    async def executemany(self, query, args):
        """Execute the given operation multiple times

        The executemany() method will execute the operation iterating
        over the list of parameters in ``args``.

        Example: Inserting 3 new employees and their phone number

            data = [
                ('Jane', '555-001'),
                ('Joe', '555-001'),
                ('John', '555-003')
                ]
            stmt = "INSERT INTO employees (name, phone) VALUES (%s, %s)"
            await cursor.executemany(stmt, data)

        Placeholders must not be quoted: values are escaped (and quoted)
        by the driver.

        INSERT or REPLACE statements are optimized by batching the data,
        that is using the MySQL multiple rows syntax.

        :param query: `str`, sql statement
        :param args: iterable of ``tuple``/``list``/``dict`` arguments, one
            per execution
        :returns: total number of affected rows, or ``None`` if ``args``
            is empty
        """
        if not args:
            return

        if self._echo:
            logger.info("CALL %s", query)
            logger.info("%r", args)

        m = RE_INSERT_VALUES.match(query)
        if m:
            # The statement head takes no parameters; formatting it with an
            # empty tuple only turns escaped "%%" into "%".
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ''
            assert q_values[0] == '(' and q_values[-1] == ')'
            return (await self._do_execute_many(
                q_prefix, q_values, q_postfix, args, self.max_stmt_length,
                self._get_db().encoding))
        else:
            rows = 0
            for arg in args:
                await self.execute(query, arg)
                rows += self._rowcount
            self._rowcount = rows
        return self._rowcount

    async def _do_execute_many(self, prefix, values, postfix, args,
                               max_stmt_length, encoding):
        """Send multi-row INSERT/REPLACE statements in batches.

        Each item of *args* is formatted into *values* and appended to
        *prefix*, joined by commas. A statement is flushed once adding the
        next row would exceed *max_stmt_length* bytes.

        :returns: total number of affected rows (also stored in rowcount)
        """
        conn = self._get_db()
        escape = self._escape_args
        if isinstance(prefix, str):
            prefix = prefix.encode(encoding)
        if isinstance(postfix, str):
            postfix = postfix.encode(encoding)
        sql = bytearray(prefix)
        args = iter(args)
        try:
            first = next(args)
        except StopIteration:
            # A truthy but exhausted iterable (e.g. an empty generator) gets
            # past executemany's emptiness check; there is nothing to insert.
            # Letting StopIteration escape a coroutine would turn it into a
            # RuntimeError.
            self._rowcount = 0
            return 0
        v = values % escape(first, conn)
        if isinstance(v, str):
            v = v.encode(encoding, 'surrogateescape')
        sql += v
        rows = 0
        for arg in args:
            v = values % escape(arg, conn)
            if isinstance(v, str):
                v = v.encode(encoding, 'surrogateescape')
            # +1 accounts for the separating comma.
            if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
                r = await self.execute(sql + postfix)
                rows += r
                sql = bytearray(prefix)
            else:
                sql += b','
            sql += v
        r = await self.execute(sql + postfix)
        rows += r
        self._rowcount = rows
        return rows

    async def callproc(self, procname, args=()):
        """Execute stored procedure procname with args

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.

        :param procname: ``str``, name of procedure to execute on server
        :param args: sequence of parameters to use with procedure
        :returns: the original args.
        """
        conn = self._get_db()
        if self._echo:
            logger.info("CALL %s", procname)
            logger.info("%r", args)

        for index, arg in enumerate(args):
            q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
            await self._query(q)
            await self.nextset()

        _args = ','.join('@_%s_%d' % (procname, i) for i in range(len(args)))
        q = f"CALL {procname}({_args})"
        await self._query(q)
        self._executed = q
        return args

    def fetchone(self):
        """Fetch the next row.

        :returns: a completed future whose result is the next row, or
            ``None`` when no rows are left
        """
        self._check_executed()
        fut = self._loop.create_future()

        if self._rows is None or self._rownumber >= len(self._rows):
            fut.set_result(None)
            return fut
        result = self._rows[self._rownumber]
        self._rownumber += 1

        fut.set_result(result)
        return fut

    def fetchmany(self, size=None):
        """Returns the next set of rows of a query result, returning a
        list of tuples. When no more rows are available, it returns an
        empty list.

        The number of rows returned can be specified using the size argument.
        If it is omitted (or falsy), :attr:`arraysize` is used.

        :param size: ``int`` number of rows to return
        :returns: a completed future whose result is the ``list`` of
            fetched rows
        """
        self._check_executed()
        fut = self._loop.create_future()
        if self._rows is None:
            fut.set_result([])
            return fut
        end = self._rownumber + (size or self._arraysize)
        result = self._rows[self._rownumber:end]
        self._rownumber = min(end, len(self._rows))

        fut.set_result(result)
        return fut

    def fetchall(self):
        """Returns all remaining rows of a query result set

        :returns: a completed future whose result is the fetched rows
        """
        self._check_executed()
        fut = self._loop.create_future()
        if self._rows is None:
            fut.set_result([])
            return fut

        if self._rownumber:
            result = self._rows[self._rownumber:]
        else:
            result = self._rows
        self._rownumber = len(self._rows)

        fut.set_result(result)
        return fut

    def scroll(self, value, mode='relative'):
        """Scroll the cursor in the result set to a new position according
         to mode.

        If mode is relative (default), value is taken as offset to the
        current position in the result set, if set to absolute, value
        states an absolute target position. An IndexError is raised in
        case a scroll operation would leave the result set; the cursor
        position is then left unchanged.

        :param int value: move cursor to next position according to mode.
        :param str mode: scroll mode, possible modes: `relative` and `absolute`
        :raises ProgrammingError: for an unknown mode.
        :raises IndexError: if the target position is outside the result.
        """
        self._check_executed()
        if mode == 'relative':
            r = self._rownumber + value
        elif mode == 'absolute':
            r = value
        else:
            raise ProgrammingError("unknown scroll mode %s" % mode)

        # A statement without a result set leaves _rows as None; treat it as
        # an empty result so callers get IndexError rather than TypeError.
        row_count = len(self._rows) if self._rows is not None else 0
        if not (0 <= r < row_count):
            raise IndexError("out of range")
        self._rownumber = r

        fut = self._loop.create_future()
        fut.set_result(None)
        return fut

    async def _query(self, q):
        """Run *q* on the connection and load its (first) result."""
        conn = self._get_db()
        self._last_executed = q
        self._clear_result()
        await conn.query(q)
        await self._do_get_result()

    def _clear_result(self):
        """Reset all per-result state."""
        self._rownumber = 0
        self._result = None

        self._rowcount = 0
        self._description = None
        self._lastrowid = None
        self._rows = None

    async def _do_get_result(self):
        """Copy the connection's current result into the cursor state."""
        conn = self._get_db()
        self._rownumber = 0
        self._result = result = conn._result
        self._rowcount = result.affected_rows
        self._description = result.description
        self._lastrowid = result.insert_id
        self._rows = result.rows

        if result.warning_count > 0:
            await self._show_warnings(conn)

    async def _show_warnings(self, conn):
        """Emit server warnings as Python warnings.

        Skipped while further result sets are pending. The message is the
        last column of each warning row.
        """
        if self._result and self._result.has_next:
            return
        ws = await conn.show_warnings()
        if ws is None:
            return
        for w in ws:
            msg = w[-1]
            warnings.warn(str(msg), Warning, 4)

    # DB-API optional extension: expose exception classes on the cursor.
    Warning = Warning
    Error = Error
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    DataError = DataError
    OperationalError = OperationalError
    IntegrityError = IntegrityError
    InternalError = InternalError
    ProgrammingError = ProgrammingError
    NotSupportedError = NotSupportedError

    def __aiter__(self):
        return self

    async def __anext__(self):
        ret = await self.fetchone()
        if ret is not None:
            return ret
        else:
            raise StopAsyncIteration  # noqa

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Always close; returning None never suppresses the exception.
        await self.close()
        return


class _DeserializationCursorMixin:
    """Mixin that decodes JSON-typed columns of result rows.

    A column whose description reports ``FIELD_TYPE.JSON`` is passed through
    :func:`json.loads`. Values that fail to decode (``ValueError`` or
    ``TypeError``) are kept as they are.
    """

    async def _do_get_result(self):
        # Let the next class in the MRO load ``self._rows`` first.
        await super()._do_get_result()
        if not self._rows:
            return
        self._rows = list(map(self._deserialization_row, self._rows))

    def _deserialization_row(self, row):
        """Return *row* with its JSON columns decoded.

        Dict rows are updated in place (keyed by column name) and returned.
        Other rows are copied into a list, decoded by position and returned
        as a tuple. ``None`` is passed through.
        """
        if row is None:
            return None
        is_mapping = isinstance(row, dict)
        target = row if is_mapping else list(row)
        for position, (column_name, column_type, *_rest) in enumerate(
                self._description):
            if column_type == FIELD_TYPE.JSON:
                key = column_name if is_mapping else position
                with contextlib.suppress(ValueError, TypeError):
                    target[key] = json.loads(target[key])
        return target if is_mapping else tuple(target)

    def _conv_row(self, row):
        """Apply the next ``_conv_row`` in the MRO, then decode JSON."""
        if row is None:
            return None
        converted = super()._conv_row(row)
        return self._deserialization_row(converted)


class DeserializationCursor(_DeserializationCursorMixin, Cursor):
    """A buffered cursor that automatically deserializes JSON-type columns.

    When a result is loaded, values of columns reported as
    ``FIELD_TYPE.JSON`` are decoded with :func:`json.loads`. Values that
    fail to decode are left as returned by the driver.
    """


class _DictCursorMixin:
    """Mixin that turns each result row into a mapping keyed by column name.

    When a column name repeats, the later occurrences are keyed as
    ``"<table_name>.<name>"``.
    """

    # You can override this to use OrderedDict or other dict-like types.
    dict_type = dict

    @staticmethod
    def _column_keys(result_fields):
        """Build one mapping key per result field, qualifying duplicates."""
        keys = []
        for column in result_fields:
            label = column.name
            if label in keys:
                label = column.table_name + '.' + label
            keys.append(label)
        return keys

    async def _do_get_result(self):
        await super()._do_get_result()
        keys = []
        if self._description:
            keys = self._column_keys(self._result.fields)
            self._fields = keys

        if keys and self._rows:
            self._rows = [self._conv_row(record) for record in self._rows]

    def _conv_row(self, row):
        """Apply the next ``_conv_row`` in the MRO, then zip with keys."""
        if row is None:
            return None
        values = super()._conv_row(row)
        return self.dict_type(zip(self._fields, values))


class DictCursor(_DictCursorMixin, Cursor):
    """A buffered cursor which returns each row as a dictionary.

    Keys are column names; the mapping class is ``dict_type``.
    """


class SSCursor(Cursor):
    """Unbuffered Cursor, mainly useful for queries that return a lot of
    data, or for connections to remote servers over a slow network.

    Instead of copying every row of data into a buffer, this will fetch
    rows as needed. The upside of this is that the client uses much less
    memory, and rows are returned much faster when traveling over a slow
    network, or if the result set is very big.

    There are limitations, though. The MySQL protocol doesn't support
    returning the total number of rows, so the only way to tell how many rows
    there are is to iterate over every row returned. Also, it currently isn't
    possible to scroll backwards, as only the current row is held in memory.
    """

    async def close(self):
        """Finish the pending unbuffered result and drain remaining sets.

        The connection is detached from the cursor even if finishing or
        draining raises. Closing an already-closed cursor is a no-op.
        """
        conn = self._connection
        if conn is None:
            return

        try:
            # Only finish the result if it is still the connection's current
            # one; otherwise another query already took over the connection.
            if self._result is not None and self._result is conn._result:
                await self._result._finish_unbuffered_query()

            while (await self.nextset()):
                pass
        finally:
            self._connection = None

    async def _query(self, q):
        """Run *q* in unbuffered mode and return the affected row count."""
        conn = self._get_db()
        self._last_executed = q
        # Reset per-result state first (as Cursor._query does), so a failing
        # query does not leave state from the previous result behind.
        self._clear_result()
        await conn.query(q, unbuffered=True)
        await self._do_get_result()
        return self._rowcount

    async def _read_next(self):
        """Read the next row from the server and pass it to ``_conv_row``.

        :returns: the converted row, or ``None`` when no row is returned
        """
        row = await self._result._read_rowdata_packet_unbuffered()
        row = self._conv_row(row)
        return row

    async def fetchone(self):
        """ Fetch next row, or ``None`` when the result is exhausted. """
        self._check_executed()
        row = await self._read_next()
        if row is None:
            return
        self._rownumber += 1
        return row

    async def fetchall(self):
        """Fetch all, as per MySQLdb. Pretty useless for large queries, as
        it is buffered.
        """
        rows = []
        while True:
            row = await self.fetchone()
            if row is None:
                break
            rows.append(row)
        return rows

    async def fetchmany(self, size=None):
        """Returns the next set of rows of a query result, returning a
        list of tuples. When no more rows are available, it returns an
        empty list.

        The number of rows returned can be specified using the size argument.
        If it is ``None``, :attr:`arraysize` is used.

        :param size: ``int`` number of rows to return
        :returns: ``list`` of fetched rows
        """
        self._check_executed()
        if size is None:
            size = self._arraysize

        rows = []
        for i in range(size):
            row = await self._read_next()
            if row is None:
                break
            rows.append(row)
            self._rownumber += 1
        return rows

    async def scroll(self, value, mode='relative'):
        """Scroll the cursor in the result set to a new position
        according to mode. Same as :meth:`Cursor.scroll`, but moves the
        cursor on the server side one row at a time: scrolling 20 rows
        forward reads and discards 20 rows. Currently only forward
        scrolling is supported.

        :param int value: move cursor to next position according to mode.
        :param str mode: scroll mode, possible modes: `relative` and `absolute`
        :raises NotSupportedError: when asked to scroll backwards.
        :raises ProgrammingError: for an unknown mode.
        """

        self._check_executed()

        if mode == 'relative':
            if value < 0:
                raise NotSupportedError("Backwards scrolling not supported "
                                        "by this cursor")

            for _ in range(value):
                await self._read_next()
            self._rownumber += value
        elif mode == 'absolute':
            if value < self._rownumber:
                raise NotSupportedError(
                    "Backwards scrolling not supported by this cursor")

            end = value - self._rownumber
            for _ in range(end):
                await self._read_next()
            self._rownumber = value
        else:
            raise ProgrammingError("unknown scroll mode %s" % mode)


class SSDictCursor(_DictCursorMixin, SSCursor):
    """An unbuffered cursor which returns each row as a dictionary.

    Rows are read from the server one at a time (as in :class:`SSCursor`)
    and converted to ``dict_type`` mappings keyed by column name.
    """
