import cv2
import numpy as np
import os
import tensorflow as tf
import sys
import csv
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split

"""Module for training convolutional neural network for regression of the rotation angle 
   Usage:
   > python TrainRotationCNN.py <rotationDataSet1> <rotationDataSet2> <boundingBox> <outputCNNFile>
   
   Input arguments:
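   * <rotationDataSet1> - Directory with the first part of the rotation data set (one image per angle, named <angle>.png)
   * <rotationDataSet2> - Directory with the second part of the rotation data set (same naming; labelled half a degree further on)
   * <boundingBox> - CSV file whose first line gives the bounding box (row start, row end, column start, column end) of the rotating product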
   * <outputCNNFile> - File name for saving the trained network

"""

# Module metadata: authorship, licence and version information.
__author__ = "Ondrej Klima"
__copyright__ = "Copyright 2020"
__credits__ = ["Ondrej Klima"]
__email__ = "iklima@fit.vutbr.cz"
__license__ = "BUT"
__version__ = "1.0"
__maintainer__ = "Ondrej Klima"

# Experimental custom loss for angles normalised to one full turn. The training
# labels built in main() are angle / 360, i.e. a period of 1.0.
def my_loss_fn(y_true, y_pred, period=1.0):
    """Return the mean squared *circular* error between normalised angles.

    Angles that differ by a whole multiple of ``period`` are treated as equal.
    For example, predicting 0.99 for a target of 0.01 costs 0.02**2 rather
    than 0.98**2.

    Args:
        y_true: Target angles, expressed in units where one turn == ``period``.
        y_pred: Predicted angles; must broadcast against ``y_true``.
        period: Length of one full turn in label units (1.0 for angle / 360).

    Returns:
        Tensor of per-sample losses: the squared wrapped difference averaged
        over the last axis.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Labels may arrive as float64 (e.g. from NumPy); match the prediction dtype.
    y_true = tf.cast(y_true, y_pred.dtype)
    # floormod maps the signed difference into [0, period). Taking the shorter
    # of the two arcs then gives the true angular distance even when the
    # unbounded linear output drifts outside [0, period). The previous
    # abs()-based form was only correct for |y_true - y_pred| <= 1.5.
    diff = tf.math.floormod(y_true - y_pred, period)
    squared_difference = tf.square(tf.math.minimum(diff, period - diff))
    return tf.reduce_mean(squared_difference, axis=-1)  # note the `axis=-1`

def _read_bounding_box(file_name):
    """Read the crop rectangle from the first row of a CSV file.

    The first four comma-separated fields are parsed as integers and returned
    as ``(row_start, row_end, col_start, col_end)``. Images are cropped as
    ``img[row_start:row_end, col_start:col_end]``. Remaining rows and fields
    are ignored.

    Raises:
        ValueError: If the file is empty, its first row has fewer than four
            fields, or a field is not an integer.
    """
    with open(file_name, newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',', quotechar='|')
        first_row = next(reader, None)
    if first_row is None or len(first_row) < 4:
        raise ValueError('Bounding box file "%s" must contain four '
                         'comma-separated integers on its first line' % file_name)
    return tuple(int(value) for value in first_row[:4])


def _load_crops(directory, angles, bbox, size=(256, 256)):
    """Load ``<directory>/<angle>.png`` for every angle, then crop and resize it.

    Args:
        directory: Directory holding one PNG per angle, named ``<angle>.png``.
        angles: Iterable of angles to load, in output order.
        bbox: ``(row_start, row_end, col_start, col_end)`` crop rectangle.
        size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        List of cropped and resized images. These are 3-channel BGR uint8,
        as produced by cv2.imread's default colour mode.

    Raises:
        IOError: If an image file is missing or cannot be decoded.
    """
    row_start, row_end, col_start, col_end = bbox
    crops = []
    for angle in angles:
        file_name = os.path.join(directory, '%s.png' % angle)
        img = cv2.imread(file_name)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise IOError('Cannot read image "%s"' % file_name)
        crops.append(cv2.resize(img[row_start:row_end, col_start:col_end], size))
    return crops


def _build_model(input_shape=(256, 256, 3)):
    """Build the (uncompiled) angle-regression CNN.

    The network has four stages of Conv(3x3, same) -> ReLU -> BatchNorm ->
    MaxPool(2x2), with 16/32/64/128 filters. These are followed by
    Dense(360) -> ReLU -> BatchNorm -> Dropout(0.5) -> Dense(4) -> ReLU ->
    Dense(1, linear).
    """
    chan_dim = -1  # inputs are channels-last, so normalise over the last axis
    inputs = Input(shape=input_shape)
    x = inputs
    for filters in (16, 32, 64, 128):
        x = Conv2D(filters, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=chan_dim)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Flatten()(x)
    x = Dense(360)(x)
    x = Activation("relu")(x)
    x = BatchNormalization(axis=chan_dim)(x)
    x = Dropout(0.5)(x)

    x = Dense(4)(x)
    x = Activation("relu")(x)

    x = Dense(1, activation="linear")(x)
    return Model(inputs, x)


def main():
    """Train the rotation-angle regression CNN configured by ``sys.argv``.

    Expected arguments: <rotationDataSet1> <rotationDataSet2> <boundingBox>
    <outputCNNFile>. The trained model is saved to <outputCNNFile>.

    Raises:
        IndexError: If a command-line argument is missing.
        ValueError: If an input directory or the bounding box file does not
            exist, or the bounding box file is malformed.
        IOError: If an input image cannot be read.
    """
    # Parsing input arguments
    argv = sys.argv

    try:
        dataSet1Path = argv[1]
    except IndexError:
        raise IndexError('Directory path containg the first part of data set must be supplied as an argument')

    try:
        dataSet2Path = argv[2]
    except IndexError:
        raise IndexError('Directory path containg the second part of data set must be supplied as an argument')

    try:
        boundingBoxFileName = argv[3]
    except IndexError:
        raise IndexError('Bounding box file name must be supplied as an argument')
    # Was `path.isfile`, which raised NameError because `path` is never imported.
    if not os.path.isfile(boundingBoxFileName):
        raise ValueError('File "%s" does not exist!' % boundingBoxFileName)

    try:
        saveFileName = argv[4]
    except IndexError:
        raise IndexError('Output file name for the CNN must be supplied as an argument')

    for directory in (dataSet1Path, dataSet2Path):
        if not os.path.isdir(directory):
            raise ValueError('Directory "%s" does not exist!' % directory)

    angles = range(0, 360)
    bbox = _read_bounding_box(boundingBoxFileName)

    # Order is all of data set 1, then all of data set 2. It must match the
    # label order built below.
    images = (_load_crops(dataSet1Path, angles, bbox)
              + _load_crops(dataSet2Path, angles, bbox))
    # Keras trains in float32 anyway. Building the array as float32 halves
    # memory compared with the float64 array that `np.array(images) / 255.`
    # would produce.
    train_imgs = np.asarray(images, dtype=np.float32) / 255.
    del images  # free the uint8 crops

    # Each label is the angle normalised to one full turn. Data set 2 is
    # labelled half a degree further on (presumably it was captured at
    # half-degree offsets -- TODO confirm).
    angle_values = np.asarray(angles, dtype=np.float64)
    train_labels = np.concatenate([angle_values / 360.,
                                   (angle_values + 0.5) / 360.])

    # test_size=385 is an absolute sample count, not a fraction.
    (y1, y2, x1, x2) = train_test_split(train_labels, train_imgs,
                                        test_size=385, random_state=37)
    del train_imgs  # the split holds its own copies

    model = _build_model(input_shape=x1.shape[1:])

    # The legacy `decay` argument scaled the rate as
    # lr / (1 + decay * iterations), and `lr` was the old name for
    # `learning_rate`. Both are rejected by newer Keras optimizers (TF >= 2.11).
    # InverseTimeDecay with decay_steps=1 reproduces the same schedule there.
    lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=1e-3, decay_steps=1, decay_rate=1e-3 / 200)
    opt = Adam(learning_rate=lr_schedule)
    model.compile(loss="mean_squared_error", optimizer=opt)

    # Train alternately on each split while validating on the other; do the
    # pair twice. The optimizer state and step count carry over between calls.
    # NOTE(review): both splits end up being trained on, so the reported
    # validation loss is not a held-out estimate.
    for _ in range(2):
        model.fit(x=x1, y=y1, validation_data=(x2, y2), epochs=400, batch_size=128)
        model.fit(x=x2, y=y2, validation_data=(x1, y1), epochs=400, batch_size=128)

    model.save(saveFileName)

# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()