#!/usr/bin/env python
# -*- coding: utf-8 -*-

import socket
import sys
import os
import re
import time
import select
import json
import threading
from modules.shared.cid import *

consumers = []
heartbeatThreadRunning = True
db = {}
interceptions = []
revisionNumber = 1

class SocketData(object):
    def __init__(self):
        self.buffer = ""
        self.zeroByteMessage = False

    def reset(self):
        self.buffer = ""
        self.zeroByteMessage = False

    def receive(self, segment):
        if len(segment) == 0:
            self.zeroByteMessage = True
        self.buffer = self.buffer + segment

    def ready(self):
        return "\n" in self.buffer
        
    def closingConnection(self):
        return self.zeroByteMessage

    def getMessage(self):
        message, self.buffer = self.buffer.split("\n", 1)
        return message

class HI2event(object):
    def __init__(self, message):
        try:
            iriMessage = eval(message)
            self.type = iriMessage[2][4][2]
            self.username = iriMessage[2][2].getNID()
            self.nidList = iriMessage[2][5]
            self.locatedMACs = {}
            self.ipOfMACs = {}
        except:
            self.type = "garbage"
    
    def parseNidList(self):
        for nid in self.nidList:
            nidType = {0: "unknown", 1: "unknown"}
            for i in range(0, 2):
                if re.match(r'^(?:[0-9]{1,3}\.){3}[0-9]{1,3}$', nid[i]) != None:
                    nidType[i] = "IPv4"
                elif re.match(r'^(?:[0-9a-f]{2}:){5}[0-9a-f]{2}$', nid[i]) != None:
                    nidType[i] = "MAC"
                elif re.match(r'^sdn-location:', nid[i]) != None:
                    nidType[i] = "location"
                else:
                    nidType[i] = "others"
            
            if nidType == {0: "MAC", 1: "IPv4"}:
                self.ipOfMACs[nid[0]] = nid[1]
            
            elif nidType == {0: "IPv4", 1: "MAC"}:
                self.ipOfMACs[nid[1]] = nid[0]
            
            elif nidType == {0: "MAC", 1: "location"}:
                split = nid[1].split(':', 1)
                #self.locatedMACs.append((nid[0], split[1])) # (mac, location)
                self.locatedMACs[nid[0]] = split[1]
            
            elif nidType == {0: "location", 1: "MAC"}:
                split = nid[0].split(':', 1)
                #self.locatedMACs.append((nid[1], split[1])) # (mac, location)
                self.locatedMACs[nid[1]] = split[1]
    
    def createDiffList(self):
        global db
        result = []
        
        if self.type == "BEGIN":
            userFound = False
            
            for user in db: # dictionary
                if user == self.username:
                    userFound = True
                    for newLocatedMAC in self.locatedMACs: # dictionary
                        entry = (newLocatedMAC, self.locatedMACs[newLocatedMAC], self.ipOfMACs[newLocatedMAC])
                        if entry not in db[user]:
                            db[user].append(entry)
                            result.append(entry)
                    break
            
            if userFound == False:
                db[self.username] = []
                for newLocatedMAC in self.locatedMACs:
                    entry = (newLocatedMAC, self.locatedMACs[newLocatedMAC], self.ipOfMACs[newLocatedMAC])
                    db[self.username].append(entry)
                    result.append(entry)
        
        elif self.type == "END":
            if self.username in db: 
                for oldLocatedMAC in db[self.username]:
                    if oldLocatedMAC[0] not in self.locatedMACs:
                        result.append(oldLocatedMAC)
                        
                db[self.username] = []
                for newLocatedMAC in self.locatedMACs:
                    entry = (newLocatedMAC, self.locatedMACs[newLocatedMAC], self.ipOfMACs[newLocatedMAC])
                    db[self.username].append(entry)
                
                if len(db[self.username]) == 0:
                    del db[self.username]
        
        
        return result

def send_all_data(socket):
    for user in db:
        for locatedMAC in db[user]:
                message = {'type': 'BEGIN', 'username': user, 'mac': locatedMAC[0], 'ip': locatedMAC[2], 'location': locatedMAC[1]}
                jsonMessage = json.dumps(message) + "\n"
                socket.sendall(jsonMessage)
                print jsonMessage

def event_handler(data):
    event = HI2event(data)
    if event.type == "garbage":
        return;
    
    event.parseNidList()
    locatedMACs = event.createDiffList()
    
    for locatedMAC in locatedMACs:
        message = {'type': event.type, 'username': event.username, 'mac': locatedMAC[0], 'ip': locatedMAC[2], 'location': locatedMAC[1]}
        jsonMessage = json.dumps(message) + "\n"
        
        for consumer in consumers:
            consumer.sendall(jsonMessage)
        print jsonMessage

class heartbeatThread(threading.Thread):
    def run(self):
        while heartbeatThreadRunning:
            time.sleep(5)
            threads = threading.enumerate()
            for thread in threads:
                if thread.isAlive() == False:
                    exit(1)
            #for consumer in consumers:
            #    consumer.sendall("HEARTBEAT\n")

def create_hi1_message(mode, nid, version):
    global interceptions
    
    if ":" in nid:
        nidtype, username = nid.split(':', 1)
    else:
        username = nid
    
    liid = username + "-" + str(version)
    message = "('"+mode+"', HI1Intercept('"+liid+"', '"+nid+"', '3', '0', '2147483647', ''))\n"
    print "send to hi1", message
    return message
  
def delete_old_interceptions(hi1Port):
    global interceptions, revisionNumber
    for index in range(0, len(interceptions)):
        if interceptions[index][1] < revisionNumber and interceptions[index][2] != 0:
            hi1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            hi1.connect(("127.0.0.1", hi1Port))
            hi1.sendall(create_hi1_message("delete_intercept", interceptions[index][0], interceptions[index][2]))
            hi1.close()
            interceptions[index] = (interceptions[index][0], interceptions[index][1], 0) # zmazany odposlech
    revisionNumber = revisionNumber + 1

def is_nid_new(nid):
    for index in range(0, len(interceptions)):
        if interceptions[index][0] == nid:
            if interceptions[index][2] == 0: #zmazany odposlech
                interceptions[index] = (interceptions[index][0], revisionNumber, revisionNumber)
                return True
            else:
                interceptions[index] = (interceptions[index][0], revisionNumber, interceptions[index][2]) # update zaznamu
                return False
    
    interceptions.append((nid, revisionNumber, revisionNumber))
    return True

              
def main(argv):
    global consumers, heartbeatThreadRunning
    
    heartbeatThread().start()
    
    hi2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    hi2.connect(("0.0.0.0", int(argv[0])))
    hi2Data = SocketData();
    
    ctrl = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    ctrl.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    ctrl.bind(('', 60001))
    ctrl.listen(5)
    interceptData = SocketData();
    
    sockets = [hi2, ctrl]
    
    done = False
    
    while not done:
        try:
            readable, writable, exceptional = select.select(sockets, [], [])
            for s in readable:
                if s is hi2:
                    hi2Data.receive(s.recv(1024))
                    if hi2Data.closingConnection(): 
                        print "hi2Data closingConnection"
                        done = True
                        break
                    while hi2Data.ready():
                        event_handler(hi2Data.getMessage())
                elif s is ctrl:
                    connection, client_address = s.accept()
                    print "spojeni s kontrolerem", client_address, connection
                    sockets.append(connection)
                    consumers.append(connection)
                else:
                    '''if s is hi2:
                        print "else hi2", s
                    elif s is ctrl:
                        print "else ctrl", s
                    else:
                        print "else consumera", s
                    '''
                    interceptData.receive(s.recv(1024))
                    if interceptData.closingConnection(): 
                        print "remove socket"
                        interceptData.reset()
                        sockets.remove(s)
                        consumers.remove(s)
                        s.close()
                        continue
                    
                    while interceptData.ready():
                        message = interceptData.getMessage()
                        print "message:", message
                        if message == "load_all_data":
                            delete_old_interceptions(int(argv[1]))
                            interceptData.reset()
                            send_all_data(s)
                        elif is_nid_new(message):
                            hi1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                            hi1.connect(("127.0.0.1", int(argv[1])))
                            hi1.sendall(create_hi1_message("new_intercept", message, revisionNumber))
                            hi1.close()
                    
                            
        except KeyboardInterrupt:
            heartbeatThreadRunning = False
            for consumer in consumers:
                consumer.close()
                consumers.remove(consumer)
            break
    
    hi2.close()
    ctrl.close()

if __name__ == "__main__":
    if len(sys.argv) != 3 or ( sys.argv[1] == "-h" or sys.argv[1] == "--help" ):
        print "Skript pre vytvorenie rozhrania medzi SIMSom a kontrolerom SDN. Parametre:"
        print "1. cislo portu pre spojenie so SIMSom HI2" 
        print "2. cislo portu pre spojenie so SIMSom HI1"
        print "Priklad: " + sys.argv[0] + " 21102 21099"
    else:
        main(sys.argv[1:])
