# -*- coding: utf-8 -*-
# Management of opened sockets
#
# Copyright (C) 2011 Matěj Grégr, Michal Kajan, Libor Polčák, Vladimír Veselý
# 
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from . import li
from .li_socket import LIHeartbeatSocket
from ..tools import log, signals

import select
import socket
import sys

# Maximal time for which the wrapped poll() in getAvailableSockets waits.
# select.poll().poll() takes its timeout in milliseconds, so this is
# 6000 s (100 minutes).
SOCKET_READY_WAIT_TIMEOUT = 6000000

# Poll events that make a socket worth reading: normal and priority data,
# hang-up and error conditions.
POLL_MASK_READ = select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR

class LISocketManager(object):
    """ Creates LI sockets, keeps track of them and dispatches incoming data

    Sockets are kept in two registries: point-to-point sockets (Unix ptpt
    sockets, TCP clients and sockets added through addPtPtSocket) and server
    sockets (Unix and TCP listeners). Every tracked socket is registered with
    a select.poll() object for readability; see mainLoop for the dispatching.
    """

    def __init__(self, ptpt = (), servers = (), tcpServers = (), tcpClients = (), serversQueue = 5):
        """ Constructor

        ptpt
          Iterable sequence over 2-tuples (name, role) where name is the name
          of the socket and role is either Client or Server (it selects
          li.ptptAsClient or li.ptptAsServer). One point-to-point connection
          over Unix sockets is created for each interface. Note that
          Client-Server pairs should be established across all communicating
          programs. Also note that incorrect ordering may introduce deadlocks.
        servers
          Iterable sequence over names of Unix sockets that will stay
          listening for more than one connection.
        tcpServers
          Iterable over 3-tuples specifying TCP sockets that will listen for
          incoming network connections. The 3-tuples have the following
          format: (name, address, port). Name is a string that specifies the
          name of the callback, address is a local address that will be bound
          for listening (:: and 0.0.0.0 are also permitted), and port is the
          port where the server listens for incoming connections.
        tcpClients
          Iterable over 3-tuples (name, server, port) specifying client TCP
          sockets. Name is a string and it is used for callback function name
          specification. Server is a string specifying the hostname or IP
          address of the server and port is the port number at the server side.
        serversQueue
          Queue size for Unix server sockets (it is not passed to TCP servers).
        """
        # Point-to-point sockets. Creation order (Unix ptpt, TCP clients,
        # Unix servers, TCP servers) is kept as is, see the deadlock note above.
        unixPtpt = {n: getattr(li, "ptptAs" + r)(n, LIHeartbeatSocket)
                for (n, r) in ptpt}
        tcpPtpt = {n: li.tcpSocketAsClient(n, s, p, LIHeartbeatSocket)
                for (n, s, p) in tcpClients}
        self.__liPtptSockets = {}
        self.__liPtptSockets.update(unixPtpt)
        self.__liPtptSockets.update(tcpPtpt)
        # Server sockets
        unixServers = {n: li.socketAsServer(n, serversQueue) for n in servers}
        tcpListeners = {n: li.tcpSocketAsServer(n, *sockParams)
                for (n, *sockParams) in tcpServers}
        self.__serverSockets = {}
        self.__serverSockets.update(unixServers)
        self.__serverSockets.update(tcpListeners)
        # Poller watching every tracked socket; the fd map translates the
        # descriptors returned by poll() back to socket objects.
        self.__poller = select.poll()
        self.__fd_to_socket = {}
        for s in self.getAllSockets():
            self.__registerSocket(s)

    def __registerSocket(self, s):
        """ Registers socket s with the poller and the fd -> socket map """
        self.__poller.register(s, POLL_MASK_READ)
        self.__fd_to_socket[s.fileno()] = s

    def __unregisterSocket(self, fileno):
        """ Removes descriptor fileno from the poller and the fd -> socket map

        Raises KeyError when the descriptor is not registered.
        """
        self.__poller.unregister(fileno)
        del self.__fd_to_socket[fileno]

    def getPtptSockets(self, name):
        """ Returns all point-to-point sockets belonging to the given interface

        Sockets are matched by name prefix, so duplicates stored by
        addPtPtSocket under name + suffix are returned as well.
        NOTE(review): any other interface whose name starts with `name` also
        matches.

        name
          Name of the interface
        """
        return [s for (n, s) in self.__liPtptSockets.items() if n.startswith(name)]

    def getServerSocket(self, name):
        """ Returns the server socket registered under name (KeyError if unknown) """
        return self.__serverSockets[name]

    def getAllSockets(self):
        """ Returns a list of all sockets known to the manager (ptpt first, then servers) """
        return list(self.__liPtptSockets.values()) + list(self.__serverSockets.values())

    def addPtPtSocket(self, s, name):
        """ Adds a new point-to-point LI socket and registers it with the poller

        s
          The socket
        name
          Its name; if the name is already taken, str(id(s)) is appended to
          make it unique.
        """
        if name in self.__liPtptSockets:
            name = name + str(id(s))
        self.__liPtptSockets[name] = s
        self.__registerSocket(s)

    def tryInterfaces(self, name):
        """ Obsolete no-op kept for backward compatibility, avoid its usage

        It used to send a ("DEBUG", name, socketName) message through every
        point-to-point socket.

        name
          String identifying the ETSI function
        """
        pass

    def removeSocket(self, s):
        """ Removes the socket from the manager and unregisters it from the poller

        The socket is looked up by its file descriptor and it is not closed.
        A negative descriptor is logged as an error and nothing is removed.
        """
        fileno = s.fileno()
        if fileno < 0:
            log.error("Cannot remove socket %d, all socket numbers are positive" % fileno)
            return
        for registry in (self.__liPtptSockets, self.__serverSockets):
            for name, tracked in registry.items():
                if tracked.fileno() == fileno:
                    # Deleting is safe here only because we stop iterating
                    del registry[name]
                    self.__unregisterSocket(fileno)
                    break

    def closeSockets(self):
        """ Closes all sockets (best effort)

        Each socket is unregistered from the poller and then closed. A failure
        to unregister (e.g. the descriptor is no longer registered) does not
        prevent closing the socket, and errors are ignored so that one faulty
        socket does not stop the cleanup of the others.

        Warning: do not try to operate with sockets after a call to this function.
        """
        for s in self.getAllSockets():
            try:
                self.__unregisterSocket(s.fileno())
            except Exception:
                pass # Not (or no longer) registered; still close it below
            try:
                s.close()
            except Exception:
                pass # Best-effort shutdown

    def send(self, name, msg):
        """ Sends a message through all point-to-point sockets of the given interface

        Failures are logged per socket and do not stop sending through the
        remaining sockets.

        name
          Name of the interface
        msg
          The message passed to each socket's send()
        """
        for s in self.getPtptSockets(name):
            try:
                s.send(msg)
            except Exception as e:
                log.unhandledException("Sending messages through interfaces %s failed (socket %s): "
                        % (name, s.getName()), e)

    def mainLoop(self, func_dict):
        """ Main loop of ETSI functions

        func_dict
            Dictionary of handler functions. A handler is looked up as the
            handler kind followed by the upper-cased name of the socket
            (IFC below):
              processMessageIFC(msg, manager, socket) is called for every
                non-empty message received on a point-to-point socket;
              processRequestIFC(socket, manager) is called when a server
                socket is reported ready;
              brokenSocketIFC(socket, manager) is called for sockets that
                are not OK.
            processMessage and brokenSocket handlers return a truthy value to
            keep the loop running. Once any of them returns a falsy value, the
            rest of the current batch is still processed and then the loop
            ends. The return value of processRequest handlers is ignored.

        The loop also ends on signals.SignalException, or after logging any
        other unexpected exception (e.g. a missing handler).
        """
        def getHandler(hName, s):
            return func_dict[hName + s.getName().upper()]

        # Turn common signals into SignalException so the loop can exit cleanly
        signals.setHandlerForCommonSignals(signals.signalExceptionHandler)
        try:
            run = True
            while run:
                ready_rd, broken = self.getAvailableSockets()
                for s in broken:
                    # Only ever clear the flag, so that a stop request is not
                    # overridden by a later handler in the same batch.
                    if not getHandler("brokenSocket", s)(s, self):
                        run = False
                    if not s.isOK():
                        self.removeSocket(s)
                for s in ready_rd:
                    if not s.isPtpt():
                        try:
                            getHandler("processRequest", s)(s, self)
                        except Exception as e:
                            log.warning("processRequest %s: %s" % (s.getName(), str(e)))
                    else:
                        try:
                            msg = s.receive()
                            if msg and not getHandler("processMessage", s)(msg, self, s):
                                run = False
                        except socket.error as e:
                            log.warning("processMessage: %s" % str(e))
                            s.setNotOK()
        except signals.SignalException:
            pass # Just exit
        except Exception as e:
            log.unhandledException("MainLoop: ", e)

    def getAvailableSockets(self):
        """ Returns 2-tuple of list of sockets ready to be read and list of broken sockets

        Sockets that are not closed but not OK are reported first (as broken,
        with no ready sockets). Otherwise point-to-point sockets with already
        buffered messages are reported as ready. Only when there are none does
        the call block in poll() for at most SOCKET_READY_WAIT_TIMEOUT ms.

        Note: this call may block or it might return incomplete lists of
        sockets. However, it is guaranteed that a subset of sockets is returned
        whenever at least one socket needs processing.

        Warning: this implementation may cause starvation and repeated calls of
        this function may prefer one socket over another.
        """
        notOK = [s for s in self.getAllSockets() if not s.isClosed() and not s.isOK()]
        if notOK:
            return ([], notOK)
        buffered = [s for s in self.__liPtptSockets.values() if s.hasBufferedMessage()]
        if buffered:
            return (buffered, [])
        events = self.__poller.poll(SOCKET_READY_WAIT_TIMEOUT)
        # Every polled descriptor, including those reporting POLLHUP/POLLERR,
        # is returned as ready; read errors are expected to surface when the
        # socket is read (mainLoop marks it not OK on socket.error). Hence
        # the broken list of this path is always empty.
        ready = [self.__fd_to_socket[fd] for fd, _event in events]
        return (ready, [])
