"""
Reintroduce delta and doble delta features
with delta and delta delta I got
with 
N=1 acc=74%
N=2 acc=71%
N=3 acc=68%

Deltas didn't help. I will use version 9 as the best version I got and
use jackknifing and train a gmm on all data train and dev.
"""

import os
os.environ["OMP_NUM_THREADS"] = "1"  # must be set before sklearn/numpy spin up threads

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import librosa  # pip install librosa
from scipy.io import wavfile
from python_speech_features import mfcc, delta  # pip install python_speech_features
from sklearn.mixture import GaussianMixture
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.preprocessing import StandardScaler

TRAIN_DIR = "./Separate_data/train/sounds"
DEV_DIR = "./Separate_data/dev/sounds"
NUM_CEPS = 20        # cepstral coefficients per frame
NFFT = 1024          # FFT size for the MFCC filterbank
SILENCE_TOP_DB = 25  # frames more than 25 dB below the peak count as silence
DELTAS_N = 3         # delta window: +/- N frames

def load_audio(src_folder):
    """Earlier loader kept for reference: raw per-class MFCCs, no silence removal or deltas (unused in main)."""
    all_classes_mfcc_feats = []
    class_labels = []
    for person_class in sorted(os.listdir(src_folder)):
        class_dir = os.path.join(src_folder, person_class)
        
        class_mfcc_feats = []
        class_labels.append(person_class)

        for audio_record in sorted(os.listdir(class_dir)):
            audio_record_pth = os.path.join(class_dir, audio_record)
            freq_sampling, audio_sig = wavfile.read(audio_record_pth)

            audio_sig = audio_sig[20000:] # drop the first 20000 samples (leading noise/silence)

            mfcc_feats = mfcc(audio_sig, freq_sampling, numcep=NUM_CEPS, appendEnergy=False, nfft=NFFT) # extract MFCCs (nfft matched to get_feats)
            
            class_mfcc_feats.append(mfcc_feats) # add them to all mfcc for this class
        
        all_classes_mfcc_feats.append(class_mfcc_feats)
    return all_classes_mfcc_feats, class_labels

def remove_silence(audio_sig, top_db=SILENCE_TOP_DB):
    audio_sig = audio_sig.astype(np.float32) # librosa requires floating-point audio
    intervals = librosa.effects.split(audio_sig, top_db=top_db) # regions more than top_db below the peak are silence
    non_silent = [audio_sig[start:end] for start, end in intervals]
    if not non_silent: # whole file judged silent: return it unchanged rather than crash on concatenate
        return audio_sig
    return np.concatenate(non_silent)

def get_feats(audio_sig, freq_sampling):
    audio_sig_no_silence = remove_silence(audio_sig)
    mfcc_feats = mfcc(audio_sig_no_silence, freq_sampling, numcep=NUM_CEPS, appendEnergy=False, nfft=NFFT) # extract mfcc
    mfcc_feats -= np.mean(mfcc_feats, axis=0, keepdims=True) # removes the mean of each MFCC coefficient column-wise, flattening variations due to channel/mic
    delta_feats = delta(mfcc_feats, DELTAS_N)
    delta_delta_feats = delta(delta_feats, DELTAS_N)
    combined_feats = np.hstack((mfcc_feats, delta_feats, delta_delta_feats))
    return combined_feats
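
# Note: with NUM_CEPS = 20 the combined matrix is (num_frames, 3 * NUM_CEPS),
# i.e. 60 columns per frame: static MFCCs, deltas, and delta-deltas side by side.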

def compute_global_scaler(src_folder):
    """Fit a single StandardScaler on the pooled features of all training files."""
    all_feats = []
    for person_class in sorted(os.listdir(src_folder)):
        class_dir = os.path.join(src_folder, person_class)
        for audio_record in sorted(os.listdir(class_dir)):
            audio_record_pth = os.path.join(class_dir, audio_record)
            freq_sampling, audio_sig = wavfile.read(audio_record_pth)
            audio_sig = audio_sig[20000:] if len(audio_sig) > 20000 else audio_sig # same head trim as training
            combined_feats = get_feats(audio_sig, freq_sampling)
            all_feats.append(combined_feats)
    all_feats = np.vstack(all_feats)
    scaler = StandardScaler()
    scaler.fit(all_feats)
    return scaler

def train_gmm(src_folder, scaler):
    """Train one GMM per class on normalized features; returns {class: (gmm, scaler)}."""
    class_features = {} # store (gmm, scaler) per class

    for person_class in sorted(os.listdir(src_folder)):
        class_dir = os.path.join(src_folder, person_class)
        
        class_mfcc_feats = []

        for audio_record in sorted(os.listdir(class_dir)):
            audio_record_pth = os.path.join(class_dir, audio_record)
            freq_sampling, audio_sig = wavfile.read(audio_record_pth)

            audio_sig = audio_sig[20000:] if len(audio_sig) > 20000 else audio_sig # drop the first 20000 samples (leading noise/silence)

            combined_feats = get_feats(audio_sig, freq_sampling)
            # simple augmentation (disabled): time-stretch a copy and stack its features
            #stretched = librosa.effects.time_stretch(audio_sig.astype(np.float32), rate=0.9)
            #combined_feats = np.vstack([combined_feats, get_feats(stretched, freq_sampling)])

            class_mfcc_feats.append(combined_feats) # collect this file's features for the class

        if class_mfcc_feats:
            class_mfcc_feats = np.vstack(class_mfcc_feats)
            print(f"\nTotal frames for {person_class}: {class_mfcc_feats.shape[0]}")
            
            # Normalize features
            normalized_feats = scaler.transform(class_mfcc_feats)
            
            # train GMM for this class
            gmm = GaussianMixture(n_components=8, 
                                covariance_type='full',
                                max_iter=500,
                                n_init=3,
                                random_state=42)
            gmm.fit(normalized_feats)
            
            # store the GMM together with the shared global scaler
            class_features[person_class] = (gmm, scaler)

    return class_features
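
# Hedged sketch (not part of the original pipeline): classify a single wav file
# with the models returned by train_gmm, reusing the exact training-time
# preprocessing. The helper name and signature are illustrative assumptions.
def classify_file(audio_path, models, scaler):
    freq_sampling, audio_sig = wavfile.read(audio_path)
    audio_sig = audio_sig[20000:] if len(audio_sig) > 20000 else audio_sig # same head trim as training
    test_features = get_feats(audio_sig, freq_sampling) # silence removal + MFCC + deltas + CMN
    scores = {name: gmm.score(scaler.transform(test_features)) # average log-likelihood per frame
              for name, (gmm, _) in models.items()}
    return max(scores.items(), key=lambda x: x[1])[0] # highest-scoring class wins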


def evaluate_models(models, test_dir, scaler):
    """Evaluate trained GMM models on test data"""
    true_labels = []
    pred_labels = []
    
    for person_class in sorted(os.listdir(test_dir)):
        class_dir = os.path.join(test_dir, person_class)
        
        for audio_record in sorted(os.listdir(class_dir)):
            try:
                audio_record_pth = os.path.join(class_dir, audio_record)
                freq_sampling, audio_sig = wavfile.read(audio_record_pth)
                
                audio_sig = audio_sig[20000:] if len(audio_sig) > 20000 else audio_sig
                test_features = get_feats(audio_sig, freq_sampling)
                
                # score against every class model using the shared global scaler
                scores = {
                    name: gmm.score(scaler.transform(test_features))
                    for name, (gmm, _) in models.items()
                }
                predicted_class = max(scores.items(), key=lambda x: x[1])[0]
                
                true_labels.append(person_class)
                pred_labels.append(predicted_class)
                
            except Exception as e:
                print(f"Error processing {audio_record_pth}: {str(e)}")
                continue
    
    return true_labels, pred_labels

def plot_confusion_matrix(true_labels, pred_labels, classes):
    """Plot confusion matrix"""
    cm = confusion_matrix(true_labels, pred_labels, labels=classes)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()
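
# Hedged sketch of the plan in the module docstring ("train a gmm on all data
# train and dev"): pool train + dev per class and fit the final models. This is
# only the pooled fit, not the jackknife itself; the directory layout is
# assumed to match TRAIN_DIR/DEV_DIR above.
def train_final_gmm(train_dir=TRAIN_DIR, dev_dir=DEV_DIR):
    scaler = compute_global_scaler(train_dir) # could also be refit on the pooled data
    models = {}
    for person_class in sorted(os.listdir(train_dir)):
        feats = []
        for src in (train_dir, dev_dir):
            class_dir = os.path.join(src, person_class)
            for audio_record in sorted(os.listdir(class_dir)):
                freq_sampling, audio_sig = wavfile.read(os.path.join(class_dir, audio_record))
                audio_sig = audio_sig[20000:] if len(audio_sig) > 20000 else audio_sig
                feats.append(get_feats(audio_sig, freq_sampling))
        normalized = scaler.transform(np.vstack(feats))
        gmm = GaussianMixture(n_components=8, covariance_type='full',
                              max_iter=500, n_init=3, random_state=42)
        gmm.fit(normalized)
        models[person_class] = (gmm, scaler)
    return models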

# Main execution
if __name__ == "__main__":

    # Prepare scaler
    print("Computing global scaler...")
    scaler = compute_global_scaler(TRAIN_DIR)

    # Train models
    print("Training GMM models...")
    trained_models = train_gmm(TRAIN_DIR, scaler)
    
    # Evaluate on development set
    print("\nEvaluating on development set...")
    true_labels, pred_labels = evaluate_models(trained_models, DEV_DIR, scaler)
    
    # Get unique class names
    classes = sorted(list(trained_models.keys()))
    
    # Print classification report
    print("\nClassification Report:")
    print(classification_report(true_labels, pred_labels, labels=classes, target_names=classes, zero_division=0))
    
    # Plot confusion matrix
    plot_confusion_matrix(true_labels, pred_labels, classes)
    
    # Example of classifying a single file with the classify_file helper above
    # test_file = "path_to_test_file.wav"
    # predicted_class = classify_file(test_file, trained_models, scaler)
    # print(f"\nPredicted class: {predicted_class}")