#! /usr/bin/env python3
'''Functions for manipulating entries in the DB'''

from psqtie import WhereClause

from .db_obj import (fetch, fetchone, execute, insert_into, select_from,
                     delete_from, update)
from .errors import DBError
from ..common import Shortlink
from ..logging import debug


# Base SELECT shared by get_entries(), get_entry() and get_source_videos():
# every post row joined with its source, exposing the source's title, url and
# shortlink as source_text, source_url and source_shortlink.
# Callers append their own JOIN / WHERE / GROUP BY / ORDER BY / LIMIT clauses
# after this text.
_entry_statement = '''
    SELECT
        post.*,
        source_id,
        source.title AS source_text,
        source.url AS source_url,
        source.shortlink AS source_shortlink
    FROM post
    JOIN source USING(source_id)
    '''

def get_entries(**kwargs):
    '''Search for entries.

    Recognized keyword arguments (a falsy value keeps the default):
        since, until  -- inclusive bounds on post.posted_at
        names         -- iterable of exact post names
        source        -- source title, URL, shortlink or alternative shortlink
        hosts         -- iterable of name prefixes; matches names "<host>/..."
        tags          -- iterable of tag values
        tags_op       -- 'or' (default): posts having any of the tags;
                         anything else: posts having all of the tags
        sort_by       -- 'views', 'random', or anything else for 'posted_at'
        order         -- 'asc' / 'ascending' sorts ascending; anything else
                         (default 'descending') sorts descending; ignored
                         when sort_by is 'random'
        limit, offset -- pagination; offset also works without a limit
        expand        -- accepted but not used by this function

    Returns the result of fetch() for the built statement.

    NOTE that tags will not be included (see cache_tags()).
    '''

    # default values
    params = {
        'since': None,
        'until': None,
        'names': None,
        'source': None,
        'hosts': None,
        'tags': None,
        'tags_op': 'or',
        'sort_by': 'posted_at',
        'order': 'descending',
        'limit': None,
        'offset': None,
        'expand': False,
    }

    for key in params:
        if kwargs.get(key):
            params[key] = kwargs[key]


    # build where clause
    conditions = []
    values = {}

    # needed by the 'source' filter; the GROUP BY below collapses the
    # duplicate rows produced when a source has several alt shortlinks
    joins = ['LEFT JOIN alt_shortlink USING(source_id)']

    if params['since']:
        conditions.append('post.posted_at >= :since')
        values['since'] = params['since']

    if params['until']:
        conditions.append('post.posted_at <= :until')
        values['until'] = params['until']

    if params['names']:
        placeholders = []
        for i, name in enumerate(params['names']):
            # named parameters need the ':' prefix; without it SQLite reads
            # "name0" as a (non-existent) column reference
            placeholders.append(f':name{i}')
            values[f'name{i}'] = name

        conditions.append(f'post.name IN ({", ".join(placeholders)})')

    if params['source']:
        conditions.append(
            '''(source.title = :source
                OR source.url = :source
                OR source.shortlink = :source
                OR alt_shortlink.value = :source)'''
        )
        values['source'] = params['source']

    if params['hosts']:
        host_strs = []
        for i, host in enumerate(params['hosts']):
            host_strs.append(f'post.name LIKE :host{i}')
            values[f'host{i}'] = f'{host}/%'

        conditions.append(f'({" OR ".join(host_strs)})')

    if params['tags']:
        if params['tags_op'].upper() == 'OR':
            tag_strs = []

            for i, tag in enumerate(params['tags']):
                tag_strs.append(f'tag.value = :tag{i}')
                values[f'tag{i}'] = tag

            # make sure the video contains any of the tags
            joins.append(f'JOIN tag ON tag.post_id = post.post_id AND '
                         f'({" OR ".join(tag_strs)})')
        else:
            # make sure the video contains all the tags (one join per tag)
            for i, tag in enumerate(params['tags']):
                joins.append(
                    f'JOIN tag AS tag{i} ON tag{i}.post_id = post.post_id AND '
                    f'tag{i}.value = :tag{i}')
                values[f'tag{i}'] = tag

    where_clause = WhereClause(conditions=conditions) if conditions else ''

    group_by_str = 'GROUP BY post.post_id'


    # build order by (column names come from a fixed whitelist, never from
    # caller input, so interpolating them is safe)
    if params['sort_by'] == 'random':
        order_by_str = 'ORDER BY RANDOM()'
    else:
        order_col = 'views' if params['sort_by'] == 'views' else 'posted_at'
        # accept both the SQL keyword and the spelled-out form used by the
        # default value ('descending')
        ascending = params['order'].upper() in ('ASC', 'ASCENDING')
        direction = 'ASC' if ascending else 'DESC'
        order_by_str = f'ORDER BY {order_col} {direction}'


    # build limit (int() guarantees only numbers reach the SQL text)
    limit_str = ''
    if params['limit']:
        limit_str = f'LIMIT {int(params["limit"])}'

    if params['offset']:
        if not limit_str:
            # SQLite only accepts OFFSET as part of a LIMIT clause; a negative
            # LIMIT means "no upper bound"
            limit_str = 'LIMIT -1'
        limit_str += f' OFFSET {int(params["offset"])}'


    # build statement
    joins_str = '\n'.join(joins)

    statement = f'''
        {_entry_statement}
        {joins_str}
        {where_clause}
        {group_by_str}
        {order_by_str}
        {limit_str}
        '''

    debug('params=%s', params)
    debug('statement=%s', statement)
    debug('values=%s', values)

    return fetch(statement, values)

def get_entry(name):
    '''Fetch the single entry whose post name equals *name*.

    The row comes from fetchone() over the shared entry SELECT, so it carries
    the post columns plus the source_* aliases. Tags are not part of it
    (see cache_tags()).
    '''
    return fetchone(f'''
        {_entry_statement}
        WHERE name = ?
        ''', name)


def get_source(source):
    '''Look up a source by title, URL, shortlink or alternative shortlink.

    A falsy *source* short-circuits to None without touching the DB;
    otherwise the row returned by fetchone() is passed through.
    '''
    if source:
        lookup_sql = '''
        SELECT source.*
        FROM source
        LEFT JOIN alt_shortlink USING(source_id)
        WHERE (
            title = :source
            OR url = :source
            OR shortlink = :source
            OR alt_shortlink.value = :source
        )
        '''
        return fetchone(lookup_sql, source=source)

    return None


def get_source_videos(source_id):
    '''List every video belonging to the source with id *source_id*.

    Returns None for a falsy *source_id*, otherwise a list of entry rows
    (possibly empty).
    '''
    if source_id:
        by_source_sql = f'''
        {_entry_statement}
        WHERE source_id = ?
        '''
        return [*fetch(by_source_sql, source_id)]

    return None


def get_all_sources():
    '''List sources together with how many videos each one has.

    Every row holds the source columns plus ``videos`` (the post count), and
    rows are ordered by that count (highest first), then by title. Because
    of the inner JOIN on post, sources without any post are not returned.
    '''
    rows = fetch('''
        SELECT *
        FROM (
            SELECT
                source.*,
                COUNT(post_id) AS videos
            FROM source
            JOIN post USING(source_id)
            GROUP BY source_id
        ) AS subquery
        ORDER BY videos DESC, title ASC
        ''')
    return [*rows]


def cache_tags(entries):
    '''Attach a comma-separated ``tags`` string to each entry, in place.

    entries -- iterable of mutable mappings with a 'post_id' key. It is
               consumed exactly once, so generators are accepted too.

    For every entry whose post has at least one tag, entry['tags'] is set to
    the tag values joined with ', ' (the subquery sorts them by value, which
    SQLite honours in practice but does not formally guarantee). Entries
    without tags are left untouched (no 'tags' key is added). If the same
    post appears several times, every copy is updated.
    '''
    # materialize once: the entries are walked twice below
    entries = list(entries)
    if not entries:
        # nothing to tag; skip the DB round-trip entirely
        return

    # index by post_id so each result row can be routed to its entries.
    # int() both matches the integer ids the DB returns and guarantees the
    # ids are safe to inline into the SQL text below.
    entries_by_id = {}
    for entry in entries:
        entries_by_id.setdefault(int(entry['post_id']), []).append(entry)

    id_list = ', '.join(str(post_id) for post_id in entries_by_id)

    statement = f'''
        SELECT
            post_id,
            GROUP_CONCAT(value, ', ') AS tags
        FROM (
            SELECT
                post_id,
                value
            FROM tag
            WHERE post_id IN ({id_list})
            ORDER BY value
        ) AS subquery
        GROUP BY post_id
        '''

    for row in fetch(statement):
        for entry in entries_by_id[row['post_id']]:
            entry['tags'] = row['tags']


def get_videos_per_tag(tags=None):
    '''Count how many tag rows exist for each tag value.

    tags -- optional collection of tag values to restrict the count to; when
            falsy, every tag is counted.

    Returns a dict mapping tag value -> count (the number of videos, assuming
    a tag is never stored twice for the same post), built in the query's
    order: count descending, then value ascending.
    '''
    tags = tags or []
    where_sql = (
        f'WHERE value IN ({", ".join(["?"] * len(tags))})' if tags else '')

    statement = f'''
        SELECT *
        FROM (
            SELECT
                value,
                COUNT(*) AS qty
            FROM tag
            {where_sql}
            GROUP BY value
        ) AS subquery
        ORDER BY qty DESC, value ASC
        '''

    rows = fetch(statement, *tags)
    return dict((row['value'], row['qty']) for row in rows)


def insert_source(title, url):
    '''Create a source (and its alternative shortlinks) and return its row.

    An empty *url* is stored as None. The shortlink columns come from
    Shortlink.get_source_shortlinks(). Raises DBError when the freshly
    inserted row cannot be read back with select_from().
    '''
    new_source = {'title': title, 'url': url or None}

    links = Shortlink.get_source_shortlinks(new_source)
    new_source.update(
        (column, links[column]) for column in ('shortlink', 'url_shortlink'))

    # write the source, then read it back to learn its source_id
    insert_into('source', new_source)
    stored = select_from('source', new_source)
    if not stored:
        raise DBError('Failed to insert source')

    for alt in links['alt_shortlinks']:
        alt_row = {'source_id': stored['source_id'], 'value': alt}
        insert_into('alt_shortlink', alt_row)

    return stored


def insert_entry(entry):
    '''Insert an entry into the DB.

    The expected input format is the same as what is output by the json api:
    'source_text', 'name' and 'posted_at' are required (non-empty);
    'source_url', 'views' and 'tags' are optional.

    Does nothing if an entry with the same name already exists. If no source
    matches 'source_text', a new source is created, unless 'source_url'
    already matches an existing source.

    NOTE that auto-commit should be disabled before calling this function, and
    also that the caller is responsible for committing all changes.

    Raises:
        ValueError: if a required key is missing or empty, or if the source
            URL already belongs to a source with a different title.
    '''
    for key in ('source_text', 'name', 'posted_at'):
        if not entry.get(key):
            raise ValueError(f'Malformed entry (missing key {key}): {entry}')

    # check if the entry already exists
    if get_entry(entry['name']):
        return

    # reuse the source matching the title, or create a new one
    source = get_source(entry['source_text'])
    if not source:
        # explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would let insert_source() run for a URL that
        # already belongs to another source
        if get_source(entry.get('source_url')):
            raise ValueError(
                f'Source URL matches multiple titles: {entry["source_url"]}')

        source = insert_source(entry['source_text'], entry.get('source_url'))

    # insert post
    post = {
        'source_id': source['source_id'],
        'name': entry['name'],
        'posted_at': entry['posted_at'],
        # make sure we don't set this to null
        'views': entry.get('views') or 0,
    }

    insert_into('post', post)

    # insert tags, if any; re-read the entry to learn the new post_id
    if entry.get('tags'):
        post_id = get_entry(entry['name'])['post_id']

        for tag in entry['tags']:
            insert_into('tag', {'post_id': post_id, 'value': tag})


def increment_video_views(name):
    '''Add one to the view counter of the video called *name*.

    Returns the new count, computed from the value read just before the
    UPDATE, or None when no such video exists.
    '''
    found = get_entry(name)
    if found:
        execute('UPDATE post SET views = views + 1 WHERE post_id = ?',
                found['post_id'])
        return found['views'] + 1

    return None


def update_video_source(entry, source_id):
    '''Move a video to another source, dropping the old source if now empty.

    entry     -- the video row; its 'source_id' and 'post_id' are read.
    source_id -- id of the source the video should now belong to.

    When the previous source has no videos left, its alternative shortlinks
    are deleted first and then the source row itself.
    '''
    previous_source_id = entry['source_id']

    execute('UPDATE post SET source_id = ? WHERE post_id = ?',
            source_id, entry['post_id'])

    if get_source_videos(previous_source_id):
        return

    for table in ('alt_shortlink', 'source'):
        delete_from(table, {'source_id': previous_source_id})


def insert_tag(post_id, value):
    '''Attach the tag *value* to the video whose id is *post_id*.

    WARNING: no duplicate check happens here; the caller must make sure the
    tag isn't already present.'''
    tag_row = {'post_id': post_id, 'value': value}
    insert_into('tag', tag_row)


def delete_tag(post_id, value):
    '''Detach the tag *value* from the video whose id is *post_id*.'''
    tag_row = {'post_id': post_id, 'value': value}
    delete_from('tag', tag_row)
