from threading import Lock
import inspect
import csv


from pyretic.examples.shared import *
from pyretic.lib.corelib import *
from pyretic.lib.std import *
from pyretic.lib.query import *
from pyretic.core.runtime import virtual_field

usersFile = "pyretic/examples/configs/users.csv"
resourcesFile = "pyretic/examples/configs/resources.csv"
rulesFile = "pyretic/examples/configs/stats.csv"

class statistics(DynamicPolicy):
    def __init__(self):
        self.debug = 0
        self.network = None
        self.topology = None
        self.lock = Lock()
        self.usernames = {'default': 'default'}
        self.groups = ["default"]
        self.loadUsers()
        self.rules = {}
        self.loadRules()
        self.resources = {}
        self.loadResources()
        self.timer = 10
        self.matchFilter = None
        self.statsUsers = {}
        self.statsGroups = {}
        self.userFile = {}
        self.groupFile = {}
        self.createOutputFiles()
        super(statistics,self).__init__(true)
        self.update_policy()
    
    def createOutputFiles(self):
        for resource in self.rules:
            self.userFile[resource] = open("pyretic/examples/stats/" + resource + "-users.csv", 'w', 0)
            self.userFile[resource].write("Time")
            for group in self.rules[resource]:
                if group == "*":
                    for group2 in self.groups:
                        for user in self.usernames:
                            if self.usernames[user] == group2:
                                self.userFile[resource].write(";" + user)
                else:
                    for user in self.usernames:
                        if self.usernames[user] == group:
                            self.userFile[resource].write(";" + user)
            self.userFile[resource].write("\n")
            
            self.groupFile[resource] = open("pyretic/examples/stats/" + resource + "-groups.csv", 'w', 0)
            self.groupFile[resource].write("Time")
            for group in self.rules[resource]:
                if group == "*":
                    for group2 in self.groups:
                        self.groupFile[resource].write(";" + group2)
                else:
                    self.groupFile[resource].write(";" + group)
            self.groupFile[resource].write("\n")
    
    def proactive_counts(self, name):
        q = counts(self.timer,['srcip'])
        a = lambda counts: self.proactive_counts_printer (counts, name)
        q.register_callback(a)
        return q
    
    def diff_stats(self, resourceName):
        time = str(datetime.now().time().isoformat())
        
        self.userFile[resourceName].write(time)
        for group in self.rules[resourceName]:
            if group == "*":
                for group2 in self.groups:
                    for user in self.usernames:
                        if self.usernames[user] == group2:
                            if user in self.statsUsers[resourceName]:
                                num = int(self.statsUsers[resourceName][user][0]) - int(self.statsUsers[resourceName][user][1])
                                num = 0 if num < 0 else num
                                self.userFile[resourceName].write(";"+str(num))
                                self.statsUsers[resourceName][user] = (self.statsUsers[resourceName][user][0], self.statsUsers[resourceName][user][0])
                            else:
                                self.userFile[resourceName].write(";0")
            else:
                for user in self.usernames:
                    if self.usernames[user] == group:
                        if user in self.statsUsers[resourceName]:
                            num = int(self.statsUsers[resourceName][user][0]) - int(self.statsUsers[resourceName][user][1])
                            num = 0 if num < 0 else num
                            self.userFile[resourceName].write(";"+str(num))
                            self.statsUsers[resourceName][user] = (self.statsUsers[resourceName][user][0], self.statsUsers[resourceName][user][0])
                        else:
                            self.userFile[resourceName].write(";0")
        self.userFile[resourceName].write("\n")
        
        self.groupFile[resourceName].write(time)
        for group in self.rules[resourceName]:
            if group == "*":
                for group2 in self.groups:
                    if group2 in self.statsGroups[resourceName]:
                        num = int(self.statsGroups[resourceName][group2][0]) - int(self.statsGroups[resourceName][group2][1])
                        num = 0 if num < 0 else num
                        self.groupFile[resourceName].write(";"+str(num))
                        self.statsGroups[resourceName][group2] = (self.statsGroups[resourceName][group2][0], self.statsGroups[resourceName][group2][0])
                    else:
                        self.groupFile[resourceName].write(";0")
            else:
                if group in self.statsGroups[resourceName]:
                    num = int(self.statsGroups[resourceName][group][0]) - int(self.statsGroups[resourceName][group][1])
                    num = 0 if num < 0 else num
                    self.groupFile[resourceName].write(";"+str(num))
                    self.statsGroups[resourceName][group] = (self.statsGroups[resourceName][group][0], self.statsGroups[resourceName][group][0])
                else:
                    self.groupFile[resourceName].write(";0")
        self.groupFile[resourceName].write("\n")
    
    def proactive_counts_printer(self, counts, resourceName):
        if counts or True:
            if resourceName not in self.statsUsers:
                self.statsUsers[resourceName] = {}
            if resourceName not in self.statsGroups:
                self.statsGroups[resourceName] = {}
            
            for pred, pkt_byte_count in counts.iteritems():
                ip = str(pred['srcip'])[:-3]
                byteCount = pkt_byte_count[1]
                
                found = False
                for port in portsList:
                    for device in port["devices"]:
                        if device["ip"] == IPAddr(ip):
                            found = True
                            username = device["username"]
                            group = device["group"]
                
                if found == False:
                    username = "default"
                    group = "default"
                
                if username not in self.statsUsers[resourceName]:
                    self.statsUsers[resourceName][username] = (0, 0)
                
                userData = (byteCount, self.statsUsers[resourceName][username][1])
                self.statsUsers[resourceName][username] = userData
                
                if group not in self.statsGroups[resourceName]:
                    self.statsGroups[resourceName][group] = (0, 0)
                
                groupData = (byteCount, self.statsGroups[resourceName][group][1])
                self.statsGroups[resourceName][group] = groupData
                
            self.diff_stats(resourceName)
    
    def addResource(self, resource, address):
        ip = None
        proto = None
        port = None
        
        if ":" in address:
            ip, proto, port = address.split(':')
        else:
            ip = address
        self.resources[resource].append((ip, proto, port))
    
    def loadResources(self):
        with open(resourcesFile, 'rb') as f:
            reader = csv.reader(f, delimiter=';')
            line = 0
            for row in reader:
                line = line + 1
                if line == 1:
                    continue
                
                self.resources[row[0]] = []
                
                if "," in row[1]:
                    for address in row[1].split(","):
                        self.addResource(row[0], address)
                else:
                    self.addResource(row[0], row[1])
    
    def loadUsers(self):
        with open(usersFile, 'rb') as f:
            reader = csv.reader(f, delimiter=';')
            line = 0
            for row in reader:
                line = line + 1
                if line == 1:
                    continue
                
                if row[1] not in self.groups:
                    self.groups.append(row[1])
                
                self.usernames[row[0]] = row[1]
    
    def loadRules(self):
        with open(rulesFile, 'rb') as f:
            reader = csv.reader(f, delimiter=';')
            line = 0
            for row in reader:
                line = line + 1
                if line == 1:
                    continue
                
                self.rules[row[0]] = []
                
                for group in row[1].split(","):
                    self.rules[row[0]].append(group)
    def add_filter(self, match):
        if self.matchFilter == None:
            self.matchFilter = match
        else:
            self.matchFilter = self.matchFilter + match
    
    def update_policy (self):
        self.policy = identity
        
        for resource in self.rules:
            self.matchFilter = None
            for resourceParam in self.resources[resource]: #[('192.168.34.2', 'tcp', '80'), (), ...]
                for port in portsList:
                    for device in port["devices"]:
                        if device["ip"] == IPAddr(resourceParam[0]):
                            if resourceParam[1] == "tcp":
                                l4protocol = 6;
                            elif resourceParam[1] == "udp":
                                l4protocol = 17;
                            else:
                                l4protocol = 1
                            
                            if "*" in self.rules[resource]: #['*']
                                self.add_filter(match(switch=int(port["switch"]),outport=int(port["port"]),dstip=IPAddr(resourceParam[0]),protocol=l4protocol,dstport=int(resourceParam[2]),ethtype=IP_TYPE))
                                self.add_filter(match(switch=int(port["switch"]),inport=int(port["port"]),srcip=IPAddr(resourceParam[0]),protocol=l4protocol,srcport=int(resourceParam[2]),ethtype=IP_TYPE))
                            else:
                                for group in self.rules[resource]:
                                    self.add_filter(match(switch=int(port["switch"]),outport=int(port["port"]),dstip=IPAddr(resourceParam[0]),protocol=l4protocol,dstport=int(resourceParam[2]),group=group,ethtype=IP_TYPE))
                                    self.add_filter(match(switch=int(port["switch"]),inport=int(port["port"]),srcip=IPAddr(resourceParam[0]),protocol=l4protocol,srcport=int(resourceParam[2]),group=group,ethtype=IP_TYPE))
            
            if self.matchFilter != None:
                self.policy = self.policy + ( self.matchFilter >> (self.proactive_counts(resource) + identity) )
        
        if self.debug:
            print "statistics POLICY", self.policy
                    
def main():
    return stats()

