# Copyright (C) 2018 Libor Polčák <ipolcak@fit.vutbr.cz>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import copy
import json
import time

import dns
from time_parser import TimeWrapper, TorTimeWrapper

class onion_router():
    """ This class represents information about a single onion router at a specific consensus.

    An instance is built from the text lines of one router status entry. Each line is
    dispatched on its first word to the matching parse_<word> method (dashes replaced by
    underscores). This covers the status-entry lines ("r", "a", "s", "v", "pr", "w", "p",
    "m", "id") and the "x-..." lines this class appends itself, so stored entries can be
    read back. Parsed values are kept in name-mangled private attributes and exported by
    get_json_dict() under their unmangled names.
    """
    # When True, __init__ performs a reverse DNS lookup of the router's IP address and
    # records the answer as an extra "x-dns-reverse" line.
    request_additional = False

    # Name-mangled attributes that must be equal in both entries for merge() to combine
    # them (an attribute missing in one entry only matches if it is missing in the other).
    merge_attrs = ["_onion_router__" + attr for attr in [
        "nickname", "identity", "digest", "publication", "ip", "orport", "dirport",
        "ipv6", "flags", "version", "supported_proto", "bandwidth", "allow_ports"]]

    def __init__(self, lines):
        """ Parse one router status entry.

        lines -- iterable of text lines. They are stored unchanged so that append_lines()
                 can write them back.
        Raises NotImplementedError for a blank line or a line whose first word has no
        parse_* method.
        """
        # list() copies the caller's data and also accepts tuples/iterators; further
        # x-* lines are appended to this list later on.
        self.__all_lines = list(lines)
        # Provide some default values for optional lines
        self.__ipv6 = []
        self.__allow_ports = "Unknown"
        for line in self.__all_lines:
            words = line.split()
            try:
                parser = getattr(self, "parse_" + words[0].replace("-", "_"))
            except (IndexError, AttributeError):  # blank line / unknown keyword
                raise NotImplementedError("onion_router received line " + line) from None
            parser(*words[1:])
        if self.request_additional:
            reverse_name = dns.ip_addr_to_domain(self.get_ip())
            if reverse_name:
                queried = str(TimeWrapper(time.time()))
                self.__all_lines.append("x-dns-reverse %s %s\n" % (reverse_name, queried))
                # Store the same structure parse_x_dns_reverse() builds when the line above
                # is read back, so get_dns_reverse(), merge() and get_json_dict() see a
                # single representation. NOTE(review): assumes str(TimeWrapper) is
                # "<date> <time>", as parse_x_dns_reverse() expects -- confirm.
                self.__reverse_name = {
                        "dns": reverse_name,
                        "queried": [tuple(queried.split()[:2])],
                    }

    def append_lines(self, write_file):
        """ Write the stored lines (original plus appended x-* lines) to write_file. """
        write_file.writelines(self.__all_lines)

    def get_json_dict(self):
        """ Return the parsed attributes as a dict keyed by their unmangled names.

        The raw lines and the cached *_unixtime values are omitted, as is an empty IPv6
        list (the default set by __init__). The instance itself is not modified, so this
        method may be called repeatedly.
        """
        prefix = "_onion_router__"
        result = {}
        for attr, value in self.__dict__.items():
            if not attr.startswith(prefix):
                continue
            name = attr[len(prefix):]
            if name == "all_lines" or "unixtime" in name:
                continue
            if name == "ipv6" and not value:
                continue
            result[name] = value
        return result

    def output_json(self, write_file):
        """ Write get_json_dict() as indented JSON with sorted keys, followed by a newline. """
        write_file.write(json.dumps(self.get_json_dict(), sort_keys=True, indent=4))
        write_file.write("\n")

    def get_nickname(self):
        return self.__nickname

    def get_identity(self):
        return self.__identity

    def get_digest(self):
        return self.__digest

    def get_publication(self):
        return self.__publication

    def get_ip(self):
        return self.__ip

    def get_orport(self):
        return self.__orport

    def get_dirport(self):
        return self.__dirport

    def parse_r(self, *args):
        # "r" SP nickname SP identity SP digest SP publication SP IP SP ORPort SP DirPort NL
        # first, exactly once
        self.__nickname = args[0]
        self.__identity = args[1]
        self.__digest = args[2]
        self.__publication = (args[3], args[4]) # (date, time)
        self.__ip = args[5]
        self.__orport = args[6]
        self.__dirport = args[7]

    def get_ipv6(self):
        """ Return the list of (address, port) pairs parsed from "a" lines. """
        return self.__ipv6

    def parse_a(self, *args):
        # "a" SP address ":" port NL
        # any number
        # Expects a bracketed address "[addr]:port": drop "[", split at "]".
        addr, port = args[0][1:].split("]")
        self.__ipv6.append((addr, port[1:])) # Remove : at the port beginning

    def get_flags(self):
        return self.__flags

    def parse_s(self, *args):
        # "s" SP Flags NL
        # exactly once
        # ASCII strings
        # "Authority", "BadExit", "Exit", "Fast", "Guard", "HSDir",
        # "NoEdConsensus", "Stable", "Running", "Valid", "V2Dir"
        self.__flags = args

    def get_version(self):
        return self.__version

    def parse_v(self, *args):
        # "v" SP version NL
        # Optional, at most once
        self.__version = " ".join(args)

    def get_supported_proto(self):
        return self.__supported_proto

    def parse_pr(self, *args):
        # "pr" SP Entries NL
        # proto family element, necessary after MAY 2018
        self.__supported_proto = args

    def get_bandwidth(self):
        return self.__bandwidth

    def parse_w(self, *args):
        # "w" SP "Bandwidth=" INT [SP "Measured=" INT] [SP "Unmeasured=1"] NL
        # at most once
        # in currently kBps
        self.__bandwidth = {}
        for kv in args:
            key, value = kv.split("=", 1)
            self.__bandwidth[key] = value

    def get_allow_ports(self):
        return self.__allow_ports

    def parse_p(self, *args):
        # "p" SP ("accept" / "reject") SP PortList NL
        # At most once
        # Note that "p" is not present in archive files
        self.__allow_ports = " ".join(args)

    def parse_m(self, *args):
        # "m" SP methods 1*(SP algorithm "=" digest) NL
        pass

    def parse_id(self, *args):
        # "id" SP "ed25519" SP ed25519-identity NL
        # "id" SP "ed25519" SP "none" NL
        pass

    def _cached_unixtime(self, attr):
        """ Return TorTimeWrapper(*getattr(self, attr)).get(), cached in attr + "_unixtime".

        attr is the name-mangled attribute holding a (date, time) tuple. The value is
        computed lazily on first use.
        """
        cache_attr = attr + "_unixtime"
        try:
            return getattr(self, cache_attr)
        except AttributeError: # Lazy evaluation
            value = TorTimeWrapper(*getattr(self, attr)).get()
            setattr(self, cache_attr, value)
            return value

    def get_inconsensus_val_after(self):
        return self.__inconsensus_val_after

    def get_inconsensus_val_after_unixtime(self):
        return self._cached_unixtime("_onion_router__inconsensus_val_after")

    def get_inconsensus_fresh_until(self):
        return self.__inconsensus_fresh_until

    def get_inconsensus_fresh_until_unixtime(self):
        return self._cached_unixtime("_onion_router__inconsensus_fresh_until")

    def get_inconsensus_val_until(self):
        return self.__inconsensus_val_until

    def get_inconsensus_val_until_unixtime(self):
        return self._cached_unixtime("_onion_router__inconsensus_val_until")

    def add_network_status_consensus(self, val_after, fresh_until, val_until):
        """ Record the validity times of the consensus this entry came from.

        Each argument is a 2-tuple (date, time). It is stored and also appended as an
        x-inconsensus-* line so that it survives append_lines() and re-parsing.
        """
        self.__inconsensus_val_after = val_after
        self.__inconsensus_fresh_until = fresh_until
        self.__inconsensus_val_until = val_until
        # Drop lazily computed timestamps derived from previous values.
        for cached in ("_onion_router__inconsensus_val_after_unixtime",
                       "_onion_router__inconsensus_fresh_until_unixtime",
                       "_onion_router__inconsensus_val_until_unixtime"):
            self.__dict__.pop(cached, None)
        self.__all_lines.append("x-inconsensus-valid-after %s %s\n" % val_after)
        self.__all_lines.append("x-inconsensus-fresh-until %s %s\n" % fresh_until)
        self.__all_lines.append("x-inconsensus-valid-until %s %s\n" % val_until)

    def parse_x_inconsensus_valid_after(self, *args):
        self.__inconsensus_val_after = args

    def parse_x_inconsensus_fresh_until(self, *args):
        self.__inconsensus_fresh_until = args

    def parse_x_inconsensus_valid_until(self, *args):
        self.__inconsensus_val_until = args

    def get_dns_reverse(self):
        """ Return (reverse DNS name, list of (date, time) query times).

        Raises AttributeError when no reverse name is known.
        """
        return self.__reverse_name["dns"], self.__reverse_name["queried"]

    def parse_x_dns_reverse(self, *args):
        # "x-dns-reverse" SP name SP date SP time NL (written by __init__)
        self.__reverse_name = {
                "dns": args[0],
                "queried": [(args[1], args[2])]
            }

    def merge(self, other):
        """ Merges information from other to self if allowed.

        Only entries with the same attributes (see merge_attrs) can be merged, so only
        inconsensus and DNS reverse query time may differ. If both entries know a reverse
        DNS name, the names must be equal. The [valid-after, valid-until] intervals of the
        two entries must overlap (or touch).

        Returns True if other was merged into self (self's validity interval is widened
        and reverse-DNS query times are combined). Returns False if the entries cannot be
        merged; self's data is then left unchanged.
        """
        for attr in self.merge_attrs:
            self_has = hasattr(self, attr)
            if self_has != hasattr(other, attr):
                return False
            if self_has and getattr(self, attr) != getattr(other, attr):
                return False
        self_has_dns = hasattr(self, "_onion_router__reverse_name")
        other_has_dns = hasattr(other, "_onion_router__reverse_name")
        # Merge the two entries also if one has the DNS attr and the other does not
        if self_has_dns and other_has_dns and \
                self.__reverse_name["dns"] != other.__reverse_name["dns"]:
            return False
        left_after = self.get_inconsensus_val_after_unixtime()
        left_fresh = self.get_inconsensus_fresh_until_unixtime()
        left_until = self.get_inconsensus_val_until_unixtime()
        right_after = other.get_inconsensus_val_after_unixtime()
        right_fresh = other.get_inconsensus_fresh_until_unixtime()
        right_until = other.get_inconsensus_val_until_unixtime()
        # Overlap test: one interval must end inside the other.
        if not (right_after <= left_until <= right_until or left_after <= right_until <= left_until):
            return False
        if right_after < left_after:
            self.__inconsensus_val_after = other.__inconsensus_val_after
            self.__inconsensus_val_after_unixtime = right_after
        if right_fresh > left_fresh:
            self.__inconsensus_fresh_until = other.__inconsensus_fresh_until
            self.__inconsensus_fresh_until_unixtime = right_fresh
        if right_until > left_until:
            self.__inconsensus_val_until = other.__inconsensus_val_until
            self.__inconsensus_val_until_unixtime = right_until
        if self_has_dns and other_has_dns:
            self.__reverse_name["queried"].extend(other.__reverse_name["queried"])
        elif other_has_dns:
            # Copy, so that later merges into self do not modify other's data.
            self.__reverse_name = dict(other.__reverse_name,
                    queried=list(other.__reverse_name["queried"]))
        return True

    def get_maxmind_geolocation(self):
        return self.__geolite_maxmind_geolication

    def get_maxmind_asn(self):
        return self.__geolite_maxmind_as

    def add_geolite_data(self, data, validity, geolite_type):
        """ Store one GeoLite2 record under its UTC validity timestamp.

        data         -- mapping with the record's fields
        validity     -- seconds since the epoch, formatted with time.gmtime()
        geolite_type -- "City" or "ASN" (other values raise KeyError)
        """
        # The "geolication" spelling is kept: it is part of the attribute name and
        # therefore of the exported JSON key.
        transl = {"City": "maxmind_geolication", "ASN": "maxmind_as"}
        ts = time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(validity))
        dict_name = "_onion_router__geolite_" + transl[geolite_type]
        geolite_data = list(data.items()) + [("timestamp", ts)]
        try:
            getattr(self, dict_name)[ts] = geolite_data
        except AttributeError: # first record of this type
            setattr(self, dict_name, {ts: geolite_data})

    def append_geolite_data(self, ipaddr, geolite2_access):
        """ Add City and ASN records for ipaddr valid during this entry's consensus.

        geolite2_access.get_data(ipaddr, db, start, end) must yield (validity, data)
        pairs as accepted by add_geolite_data().
        """
        start = self.get_inconsensus_val_after_unixtime()
        end = self.get_inconsensus_val_until_unixtime()
        for db in ["City", "ASN"]:
            for validity, data in geolite2_access.get_data(ipaddr, db, start, end):
                self.add_geolite_data(data, validity, db)
