#!/usr/bin/python3
import argparse
import collections
import json
import os
import socket
import time
from threading import Thread
from datetime import datetime

import mpv


import logging
# Module logger.  NOTE(review): logging is only configured when --debug is
# given (see main()); otherwise only WARNING and above reach stderr through
# logging's last-resort handler — presumably why most progress messages in
# this file are emitted with logger.error.
logger = logging.getLogger('t_for_time')

# Median deviation (main position minus follower position, in mpv time-pos
# units, i.e. seconds) above which a follower starts correcting itself.
SYNC_TOLERANCE = 0.05
# Seconds that must pass after the last resync/speed change before the
# follower applies another correction.
SYNC_GRACE_TIME = 5
# Seconds added to the main's reported position when a follower seeks to
# it; the follower then stays paused until the main has caught up.
SYNC_JUMP_AHEAD = 1
# UDP port used both for broadcasting (main) and receiving (followers).
PORT = 9067
# The module-level value is not read by the code shown here; DEBUG logging
# is switched on by the --debug command-line flag instead.
DEBUG = False
# Subtitle styling passed to mpv when the player is created.
FONT = 'Menlo'
FONT_SIZE = 30
FONT_BORDER = 4
# Passed to mpv as sub_margin_y (vertical subtitle margin).
SUB_MARGIN = 2 * 36 + 6


def hide_gnome_overview():
    """Switch off GNOME Shell's activities overview over D-Bus.

    Writes ``False`` to the ``OverviewActive`` property of the
    ``org.gnome.Shell`` object on the session bus.  ``dbus`` is imported
    here rather than at module level; import or D-Bus errors propagate to
    the caller.
    """
    import dbus
    shell_proxy = dbus.SessionBus().get_object('org.gnome.Shell', '/org/gnome/Shell')
    dbus.Interface(shell_proxy, 'org.freedesktop.DBus.Properties').Set(
        'org.gnome.Shell', 'OverviewActive', False
    )


def mpv_log(loglevel, component, message):
    """Forward one libmpv log line to the module logger at INFO level.

    Installed as mpv's ``log_handler``; the line is rendered as
    ``[<loglevel>] <component>: <message>``.
    """
    logger.info('[%s] %s: %s', loglevel, component, message)


class Main:
    """Latest position reported by the main instance, as seen by a follower.

    Both fields hold the placeholder -1 until the first message from the
    main instance has been received.
    """

    time_pos = -1  # main's mpv time-pos value
    playlist_current_pos = -1  # main's mpv playlist index

class Sync(Thread):
    """Play an m3u playlist in mpv and keep several instances in step over UDP.

    With ``mode='main'`` (the default for this class) every ``time-pos``
    update is broadcast to ``destination:PORT`` as
    ``"<time_pos> <playlist index>"``.  With any other mode the instance is a
    follower: it reads those messages, nudges its playback speed for small
    deviations and re-seeks (:meth:`sync_to_main`) when far off.

    The constructor creates the mpv player, loads the playlist and starts the
    background thread (:meth:`run`) before returning.

    Keyword arguments:
        mode: ``'main'``, or anything else for a follower (default ``'main'``).
        playlist: path of the m3u playlist (required).
        fullscreen: passed to mpv (default ``False``).
    """

    active = True
    is_main = True
    ready = False
    destination = "255.255.255.255"
    reload_check = None
    _pos = None
    _tick = 0
    need_to_sync = False
    # Time of the last resync / speed correction.  Class-level default so
    # adjust_position() can never hit an AttributeError before a first sync.
    last_sync = 0
    # One frame at 25 fps; used for the "back to normal speed" threshold and
    # the frame counts in log messages.  NOTE(review): assumes 25 fps content.
    frame_duration = 0.04
    # Seconds without a time-pos update after which run() gives up.
    stuck_timeout = 60

    def __init__(self, *args, **kwargs):
        Thread.__init__(self)
        self.is_main = kwargs.get('mode', 'main') == 'main'
        self.sock = self.init_socket()
        self.main = Main()
        if self.is_main:
            self.socket_enable_broadcast()

        # The subtitle font options have different names depending on the
        # libmpv client API version reported by python-mpv.
        if mpv.MPV_VERSION >= (2, 2):
            font_options = dict(sub_font_size=FONT_SIZE, sub_font=FONT)
        else:
            font_options = dict(sub_text_font_size=FONT_SIZE, sub_text_font=FONT)
        self.mpv = mpv.MPV(
            log_handler=mpv_log, input_default_bindings=True,
            input_vo_keyboard=True,
            sub_border_size=FONT_BORDER,
            sub_margin_y=SUB_MARGIN,
            **font_options
        )
        self.mpv.observe_property('time-pos', self.time_pos_cb)
        self.mpv.fullscreen = kwargs.get('fullscreen', False)
        self.mpv.loop_file = False
        self.mpv.loop_playlist = True
        self.mpv.register_key_binding('q', self.q_binding)
        self.playlist = kwargs['playlist']
        self.playlist_mtime = os.stat(self.playlist).st_mtime
        self.mpv.loadlist(self.playlist)
        logger.error("loaded paylist: %s", self.playlist)
        logger.debug("current playlist: %s", json.dumps(self.mpv.playlist, indent=2))
        self.deviations = collections.deque(maxlen=10)
        if not self.is_main:
            # Briefly play so mpv has the first entry loaded, then pause and
            # jump to wherever the main instance currently is.
            self.mpv.pause = False
            time.sleep(0.1)
            self.mpv.pause = True
            self.sync_to_main()
        self.ready = True
        self.start()

    def run(self):
        """Background loop until stop() is called.

        A follower either performs a requested resync or waits (up to 5 s)
        for the next position message; the main instance just idles.  Both
        modes check the playlist file for changes and stop the player when
        no time-pos update arrived for ``stuck_timeout`` seconds.
        """
        while self.active:
            if self.is_main:
                time.sleep(0.5)
            elif self.need_to_sync:
                self.sync_to_main()
                self.deviations = collections.deque(maxlen=10)
                self.need_to_sync = False
            else:
                self.read_position_main()
            self.reload_playlist()
            if self._tick and abs(time.time() - self._tick) > self.stuck_timeout:
                logger.error("player is stuck")
                self._tick = 0
                self.stop()
                self.mpv.stop()

    def q_binding(self, *args):
        """Handler for the 'q' key: stop syncing and stop mpv playback."""
        self.stop()
        self.mpv.stop()

    def stop(self, *args):
        """Stop the background loop and close the UDP socket (safe to repeat)."""
        self.active = False
        # Detach first so other threads see None rather than a closed socket.
        sock, self.sock = self.sock, None
        if sock is not None:
            sock.close()

    def time_pos_cb(self, pos, *args, **kwargs):
        """Observer for mpv's ``time-pos`` property (registered in __init__).

        Records a liveness tick, broadcasts (main) or corrects drift
        (follower, once ready), and on a playlist-entry change resets the
        deviation history and logs the new track.
        """
        self._tick = time.time()
        if self.is_main:
            self.send_position_local()
        elif self.ready:
            self.adjust_position()
        current = self.mpv.playlist_current_pos
        if self._pos != current:
            self._pos = current
            self.deviations = collections.deque(maxlen=10)
            self.need_to_sync = False
            self._log_current_track()

    def _log_current_track(self, *extra):
        """Best-effort log of ``<now> <filename> [extra...]`` for mpv's current entry."""
        try:
            index = self.mpv.playlist_current_pos
            if index is None or index < 0:
                # -1 would wrap around to the last entry and log the wrong file.
                return
            filename = self.mpv.playlist[index]["filename"]
        except Exception:
            # Logging only: a failed lookup must not disturb playback handling.
            return
        logger.error(" ".join(["%s"] * (2 + len(extra))), datetime.now(), filename, *extra)

    def reload_playlist(self):
        """Mirror edits of the playlist file into mpv's running playlist.

        Called from every run() iteration; the file is checked at most once
        every 5 seconds and only re-read when its mtime changed.  New entries
        are appended, entries no longer listed are removed and the rest are
        moved into file order.  If the file cannot be stat'ed or read (e.g.
        while it is being replaced), nothing changes and a later call retries.
        """
        now = time.time()
        if not self.reload_check:
            self.reload_check = now
        if now - self.reload_check <= 5:
            return
        self.reload_check = now
        try:
            playlist_mtime = os.stat(self.playlist).st_mtime
        except OSError as e:
            logger.error("cannot stat playlist %s: %s", self.playlist, e)
            return
        if playlist_mtime == self.playlist_mtime:
            return
        try:
            with open(self.playlist) as fd:
                text = fd.read()
        except (OSError, ValueError) as e:
            # ValueError includes UnicodeDecodeError for undecodable content.
            logger.error("cannot read playlist %s: %s", self.playlist, e)
            return
        # Remember the mtime only after a successful read so failures retry.
        self.playlist_mtime = playlist_mtime
        items = self._parse_playlist_text(text, os.path.dirname(self.playlist))
        self._apply_playlist(items)

    @staticmethod
    def _parse_playlist_text(text, base):
        """Return the media entries of m3u ``text``, resolved against ``base``.

        Blank lines and lines starting with '#' (m3u comments and directives
        such as #EXTM3U / #EXTINF) are skipped; CRLF line endings and
        surrounding whitespace are stripped.  Absolute entries stay absolute.
        """
        stripped = (line.strip() for line in text.splitlines())
        return [os.path.join(base, entry) for entry in stripped
                if entry and not entry.startswith('#')]

    def _apply_playlist(self, items):
        """Make mpv's playlist contain exactly ``items`` in that order."""
        current_items = self.mpv.playlist_filenames
        for filename in items:
            if filename not in current_items:
                self.mpv.playlist_append(filename)
                logger.error("add: %s", filename)
        wanted = set(items)
        for filename in [name for name in current_items if name not in wanted]:
            for idx, item in enumerate(self.mpv.playlist):
                if item["filename"] == filename:
                    logger.error("remove: %s %s", idx, filename)
                    self.mpv.playlist_remove(idx)
                    break
        # Entries before idx are already in place, so each move goes from a
        # higher index down to idx.
        for idx, filename in enumerate(items):
            try:
                current_idx = self.mpv.playlist_filenames.index(filename)
            except ValueError:
                logger.error("missing from mpv playlist: %s", filename)
                continue
            if idx != current_idx:
                logger.error("move item %s %s -> %s", filename, current_idx, idx)
                self.mpv.playlist_move(current_idx, idx)
        logger.error("reloaded paylist: %s", self.playlist)
        logger.debug("current playlist: %s", json.dumps(self.mpv.playlist, indent=2))

    def init_socket(self):
        """Create the UDP socket bound to PORT on all interfaces."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, 0)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(("0.0.0.0", PORT))
        return sock

    #
    # main specific
    #
    def socket_enable_broadcast(self):
        """Allow broadcasting and fix the destination to ``destination:PORT``."""
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        self.sock.connect((self.destination, PORT))

    def send_position_local(self):
        """Broadcast ``"<time_pos> <playlist index>"`` to the followers.

        Skipped while stopping or while mpv has no position yet.
        """
        sock = self.sock
        if not self.active or sock is None:
            return
        try:
            time_pos = self.mpv.time_pos
            index = self.mpv.playlist_current_pos
        except Exception:
            # Best effort: property access can fail, e.g. while mpv shuts down.
            return
        if time_pos is None:
            return
        msg = ("%0.4f %s" % (time_pos, index)).encode()
        try:
            sock.send(msg)
        except OSError as e:
            logger.error("send failed: %s", e)

    #
    # follower specific
    #

    def read_position_main(self):
        """Wait up to 5 s for a message from main and store the newest position.

        Datagrams already queued behind the first one are drained so the
        stored position is the most recent available.  Timeouts, a closed
        socket and malformed messages are logged and leave ``self.main``
        unchanged.
        """
        sock = self.sock
        if sock is None:
            return
        try:
            sock.settimeout(5)
            payload = sock.recvfrom(1024)[0]
        except socket.timeout:
            logger.error("failed to receive data from main")
            return
        except OSError:
            logger.error("socket closed")
            return
        payload = self._drain_socket(sock, payload)
        position = self._parse_position(payload)
        if position is None:
            logger.error("ignoring malformed message from main: %r", payload)
            return
        self.main.time_pos, self.main.playlist_current_pos = position

    @staticmethod
    def _drain_socket(sock, latest, limit=256):
        """Return the newest datagram already queued on ``sock`` (or ``latest``).

        Leaves the socket in non-blocking mode; read_position_main() restores
        its timeout on the next call.  At most ``limit`` datagrams are read.
        """
        try:
            sock.settimeout(0.0)
            for _ in range(limit):
                latest = sock.recvfrom(1024)[0]
        except OSError:
            # BlockingIOError means the queue is empty (the normal exit);
            # any other OSError (e.g. closed by stop()) also ends the drain.
            pass
        return latest

    @staticmethod
    def _parse_position(payload):
        """Parse ``b"<time_pos> <index>"`` into ``(float, int)``; None if malformed."""
        try:
            time_text, index_text = payload.decode().split(" ", 1)
            return float(time_text), int(index_text)
        except ValueError:
            # Covers UnicodeDecodeError, missing field and non-numeric text.
            return None

    def adjust_position(self):
        """Compare the local position with the main's and correct drift.

        ``deviation = main - local`` (positive: this follower is behind).  The
        median of the last 10 samples drives the correction: once
        SYNC_GRACE_TIME has passed since the last correction and the median
        exceeds SYNC_TOLERANCE, the speed is changed by 0.02 when less than
        1 s off; otherwise (past the first 2 s of the track) a full resync is
        requested via ``need_to_sync``, which run() performs.  Speed goes back
        to 1.0 once a single sample is within one frame.
        """
        local_pos = self.mpv.time_pos  # read once: it may become None meanwhile
        if local_pos is None:
            return
        deviation = self.main.time_pos - local_pos
        self.deviations.append(deviation)
        median_deviation = self.median(list(self.deviations))
        frames = deviation / self.frame_duration
        median_frames = median_deviation / self.frame_duration
        if abs(deviation) <= self.frame_duration and self.mpv.speed != 1.0:
            self.mpv.speed = 1.0
            logger.error(
                '%0.05f back to normal speed %0.05f (%d) median %0.05f (%d) -> %s',
                local_pos, deviation, frames, median_deviation, median_frames, self.mpv.speed
            )
        if time.time() - self.last_sync > SYNC_GRACE_TIME and abs(median_deviation) > SYNC_TOLERANCE:
            if abs(median_deviation) < 1:
                step = 0.02
                if median_deviation > 0:
                    self.mpv.speed += step
                else:
                    self.mpv.speed -= step
                logger.error(
                    '%0.05f need to adjust speed %0.05f (%d) median %0.05f (%d) -> %s',
                    local_pos, deviation, frames, median_deviation, median_frames, self.mpv.speed
                )
                self.need_to_sync = False
                self.deviations = collections.deque(maxlen=10)
                self.last_sync = time.time()
            elif local_pos > 2 and not self.need_to_sync:
                logger.error(
                    '%0.05f need to sync %0.05f (%d)  median %0.05f (%d)',
                    local_pos, deviation, frames, median_deviation, median_frames
                )
                self.need_to_sync = True

    def median(self, lst):
        """Return the median of ``lst`` (0.0 for an empty list)."""
        ordered = sorted(lst)
        half, odd = divmod(len(ordered), 2)
        if odd:
            return ordered[half]
        return float(sum(ordered[half - 1:half + 1]) / 2.0)

    def sync_to_main(self):
        """Jump to the main's track and position, then resume in step with it.

        Seeks SYNC_JUMP_AHEAD seconds ahead of the main's reported position
        while paused, waits until the main should have caught up, unpauses,
        and records the time in ``last_sync``.
        """
        logger.error('sync to main')
        self.read_position_main()
        # Until a first message arrives, playlist_current_pos is still the -1
        # placeholder, which is not a playlist entry: keep waiting for main
        # (each read times out after 5 s) instead of switching to it.
        while self.active and self.main.playlist_current_pos < 0:
            self.read_position_main()
        if not self.active:
            self.last_sync = time.time()
            return
        if self.main.playlist_current_pos != self.mpv.playlist_current_pos:
            self.mpv.playlist_play_index(self.main.playlist_current_pos)
            self.mpv.pause = False
            self.mpv.wait_until_playing()
            self._log_current_track()
        self.mpv.pause = True
        self.mpv.speed = 1
        target = self.main.time_pos + SYNC_JUMP_AHEAD
        self.mpv.seek(target, 'absolute', 'exact')
        time.sleep(0.1)
        self.read_position_main()
        sync_timer = time.time()
        local_pos = self.mpv.time_pos
        if local_pos is None:
            # No position reported yet after the seek; assume it landed.
            local_pos = target
        wait = abs(self.main.time_pos - local_pos)
        while self.active:
            remaining = wait - (time.time() - sync_timer)
            if remaining < 0:
                self.mpv.pause = False
                self._log_current_track(target)
                break
            # Sleep in short slices instead of spinning: stays responsive to
            # stop() without burning a CPU core while waiting.
            time.sleep(min(remaining, 0.05))
        self.last_sync = time.time()


def main():
    """Parse the command line, start the sync player and block until it stops.

    ``--mode main`` makes this instance broadcast its position; any other
    value (default ``peer``) makes it follow a main instance.  ``--window``
    disables fullscreen and ``--debug`` enables DEBUG-level logging.
    """
    prefix = os.path.expanduser('~/Videos/t_for_time')

    parser = argparse.ArgumentParser(description='t_for_time sync player')
    parser.add_argument('--mode', help='peer or main', default="peer")
    parser.add_argument('--playlist', default='/srv/t_for_time/render/128/front.m3u', help="m3u")
    # NOTE(review): --prefix is parsed but args.prefix is not read by the code shown here.
    parser.add_argument('--prefix', help='video location', default=prefix)
    parser.add_argument('--window', action='store_true', help='run in window', default=False)
    parser.add_argument('--debug', action='store_true', help='debug', default=False)
    args = parser.parse_args()

    # Read the flag directly; assigning it to a local named DEBUG would only
    # shadow the module-level constant.
    if args.debug:
        log_format = '%(asctime)s:%(levelname)s:%(name)s:%(message)s'
        logging.basicConfig(level=logging.DEBUG, format=log_format)

    player = Sync(mode=args.mode, playlist=args.playlist, fullscreen=not args.window)
    try:
        while player.active:
            player.mpv.wait_for_playback()
    except (Exception, KeyboardInterrupt):
        # Any failure of the wait (presumably mpv shutting down, e.g. its
        # window being closed) or Ctrl-C ends the loop quietly.
        pass
    finally:
        player.stop()


if __name__ == "__main__":
    # Leaving the GNOME overview is best effort: dbus-python may be missing or
    # no GNOME Shell may be reachable, so any ordinary error is ignored (but
    # Ctrl-C / SystemExit are no longer swallowed).
    try:
        hide_gnome_overview()
    except Exception:
        pass
    main()