#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Speaker recognition system for SUR project
Part (b) - Audio-based person identification
"""

import os
import numpy as np
import glob
import pickle
import argparse
import librosa
from sklearn.model_selection import train_test_split
from sklearn.mixture import GaussianMixture

# Function to extract MFCC features from WAV files
def wav2mfcc(filename, n_mfcc=13, n_fft=512, hop_length=256):
    """
    Extract MFCC features from a WAV file
    
    Args:
        filename: Path to the WAV file
        n_mfcc: Number of MFCC coefficients
        n_fft: FFT window size
        hop_length: Hop length for the frames
        
    Returns:
        MFCC features
    """
    try:
        # Load audio file (resampling to 22050 Hz if needed)
        y, sr = librosa.load(filename, sr=22050, mono=True)
        
        # Extract MFCCs
        mfccs = librosa.feature.mfcc(
            y=y, 
            sr=sr, 
            n_mfcc=n_mfcc, 
            n_fft=n_fft, 
            hop_length=hop_length
        )
        
        # Transpose to get time as first dimension
        mfccs = mfccs.T
        
        # Calculate delta and delta-delta (velocity and acceleration)
        delta_mfccs = librosa.feature.delta(mfccs, axis=0)
        delta2_mfccs = librosa.feature.delta(mfccs, order=2, axis=0)
        
        # Concatenate MFCC, delta and delta-delta
        features = np.concatenate([mfccs, delta_mfccs, delta2_mfccs], axis=1)
        
        return features
    except Exception as e:
        print(f"Error processing {filename}: {e}")
        return None
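
# Usage sketch for wav2mfcc ("some_recording.wav" is a hypothetical path, not a
# file from the dataset): with the defaults (13 MFCCs plus deltas and
# delta-deltas) the returned array has shape (n_frames, 39); for a 3 s mono
# clip at 22050 Hz with hop_length=256 that is roughly 22050 * 3 / 256 ≈ 259
# frames.
#
#     feats = wav2mfcc("some_recording.wav")
#     if feats is not None:
#         print(feats.shape)  # e.g. (259, 39)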

def load_data(data_dir="SUR_projekt2024-2025", subset="train", use_dev_for_training=False, dev_ratio=0.5):
    """
    Load WAV files from data directory and extract features
    
    Args:
        data_dir: Base directory containing the data
        subset: 'train' or 'dev' subset
        use_dev_for_training: Whether to use a portion of dev data for training
        dev_ratio: Portion of dev data to use for training (0.0-1.0)
        
    Returns:
        features: List of feature matrices (one per file)
        labels: Class labels
        filenames: Original filenames
    """
    features = []
    labels = []
    filenames = []
    
    # Path to the subset directory
    subset_dir = os.path.join(data_dir, subset)
    
    # Check if directory exists
    if not os.path.exists(subset_dir):
        print(f"Directory {subset_dir} not found. Checking alternative path...")
        # Try to check if data_dir already includes the subset
        if os.path.basename(data_dir) == subset:
            subset_dir = data_dir
        else:
            # Check if the directory structure might be different
            parent_dir = os.path.dirname(data_dir)
            alt_dir = os.path.join(parent_dir, subset)
            if os.path.exists(alt_dir):
                subset_dir = alt_dir
            else:
                raise FileNotFoundError(f"Cannot find data directory. Tried {subset_dir} and {alt_dir}")
    
    print(f"Using data directory: {subset_dir}")
    
    # Loop through each class directory (1-31)
    for class_id in range(1, 32):
        class_dir = os.path.join(subset_dir, str(class_id))
        
        # Skip if class directory doesn't exist
        if not os.path.exists(class_dir):
            print(f"Warning: Class directory {class_dir} not found")
            continue
            
        # Get all WAV files in the class directory
        wav_files = glob.glob(os.path.join(class_dir, "*.wav"))
        
        print(f"Processing class {class_id}, found {len(wav_files)} WAV files")
        
        # Process each WAV file
        for wav_file in wav_files:
            # Extract features
            feature_matrix = wav2mfcc(wav_file)
            
            if feature_matrix is not None and len(feature_matrix) > 0:
                features.append(feature_matrix)
                labels.append(class_id)
                filenames.append(os.path.basename(wav_file))
    
    # Add a portion of dev data if requested and if we're loading train data
    if use_dev_for_training and subset == "train":
        dev_dir = os.path.join(data_dir, "dev")
        
        if os.path.exists(dev_dir):
            print(f"Adding {dev_ratio*100:.0f}% of dev data to training...")
            
            for class_id in range(1, 32):
                class_dir = os.path.join(dev_dir, str(class_id))
                
                if not os.path.exists(class_dir):
                    continue
                
                # Get all WAV files in the dev class directory (sorted so the
                # split is deterministic and matches create_remaining_dev_files)
                wav_files = sorted(glob.glob(os.path.join(class_dir, "*.wav")))
                
                # Calculate how many files to use for training
                num_train_files = max(1, int(len(wav_files) * dev_ratio))
                
                # Select the first portion of the sorted files for training
                train_files = wav_files[:num_train_files]
                
                print(f"Adding {len(train_files)}/{len(wav_files)} dev files for class {class_id}")
                
                # Process each training file from dev
                for wav_file in train_files:
                    feature_matrix = wav2mfcc(wav_file)
                    
                    if feature_matrix is not None and len(feature_matrix) > 0:
                        features.append(feature_matrix)
                        labels.append(class_id)
                        filenames.append("dev_" + os.path.basename(wav_file))
        else:
            print(f"Dev directory {dev_dir} not found, using only train data.")
    
    return features, np.array(labels), filenames
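
# Usage sketch (assumes the SUR_projekt2024-2025/train layout with one
# subdirectory per class, 1-31):
#
#     features, labels, filenames = load_data("SUR_projekt2024-2025", "train")
#     # features:  list of (n_frames_i, 39) arrays, one per recording
#     # labels:    np.ndarray of ints in 1..31, aligned with features
#     # filenames: basenames of the source WAV files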

def train_gmm_models(features, labels, n_components_range=[4, 8, 16, 32]):
    """
    Train GMM models for each class
    
    Args:
        features: List of feature matrices
        labels: Class labels
        n_components_range: Range of GMM components to try
        
    Returns:
        Dictionary of GMM models for each class, best number of components
    """
    # Group features by class
    class_features = {}
    for i, label in enumerate(labels):
        if label not in class_features:
            class_features[label] = []
        class_features[label].append(features[i])
    
    # Find best number of components using cross-validation
    best_n_components = find_best_n_components(class_features, n_components_range)
    print(f"Best number of GMM components: {best_n_components}")
    
    # Train final GMM models for each class
    gmm_models = {}
    for class_id, class_feat_list in class_features.items():
        # Concatenate all feature matrices for this class
        all_features = np.vstack(class_feat_list)
        
        print(f"Training GMM for class {class_id} with {len(all_features)} frames")
        
        # Initialize and train GMM with increased regularization
        gmm = GaussianMixture(
            n_components=best_n_components, 
            covariance_type='diag',
            max_iter=200,
            random_state=42,
            reg_covar=1e-2  # Increased regularization to prevent singular covariance matrices
        )
        gmm.fit(all_features)
        
        # Store model
        gmm_models[class_id] = gmm
    
    return gmm_models, best_n_components
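
# Note on scoring: GaussianMixture.score(X) returns the *average* per-frame
# log-likelihood of X, so scores from recordings of different lengths are
# directly comparable. A minimal sketch (assumes training data has already
# been loaded with load_data):
#
#     gmm_models, n_comp = train_gmm_models(features, labels)
#     print(gmm_models[labels[0]].score(features[0]))  # avg log-likelihood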

def find_best_n_components(class_features, n_components_range):
    """
    Find the best number of GMM components using cross-validation
    
    Args:
        class_features: Dictionary of feature matrices for each class
        n_components_range: Range of GMM components to try
        
    Returns:
        Best number of components
    """
    scores = []
    
    # Sample a few classes for validation to speed up the process
    # (seeded so the selection is reproducible, like random_state=42 elsewhere)
    rng = np.random.default_rng(42)
    sample_classes = rng.choice(list(class_features.keys()),
                                min(5, len(class_features)),
                                replace=False)
    
    for n_comp in n_components_range:
        print(f"Evaluating GMM with {n_comp} components...")
        class_scores = []
        
        for class_id in sample_classes:
            # Get features for this class
            class_feat_list = class_features[class_id]
            
            if len(class_feat_list) >= 3:  # Need at least 3 recordings for train/val split
                # Split into train/validation sets
                train_idx, val_idx = train_test_split(
                    range(len(class_feat_list)), 
                    test_size=0.3, 
                    random_state=42
                )
                
                train_feats = [class_feat_list[i] for i in train_idx]
                val_feats = [class_feat_list[i] for i in val_idx]
                
                # Concatenate training features
                train_all = np.vstack(train_feats)
                
                # Train GMM with increased regularization
                gmm = GaussianMixture(
                    n_components=n_comp, 
                    covariance_type='diag',
                    max_iter=200,
                    random_state=42,
                    reg_covar=1e-2  # Increased regularization to prevent singular covariance matrices
                )
                
                try:
                    gmm.fit(train_all)
                    
                    # Calculate log-likelihood on validation set
                    val_scores = []
                    for val_feat in val_feats:
                        val_scores.append(gmm.score(val_feat))
                    
                    # Average validation score
                    avg_score = np.mean(val_scores)
                    class_scores.append(avg_score)
                except Exception as e:
                    print(f"Warning: Failed to train GMM with {n_comp} components for class {class_id}: {e}")
        
        # Average score across classes
        if class_scores:
            avg_score = np.mean(class_scores)
            scores.append((n_comp, avg_score))
            print(f"  Average log-likelihood: {avg_score:.4f}")
        else:
            print(f"  Failed to get scores for {n_comp} components")
    
    # Get best number of components
    if scores:
        scores.sort(key=lambda x: x[1], reverse=True)
        best_n_components = scores[0][0]
    else:
        # Default if validation couldn't be performed
        best_n_components = 16
    
    return best_n_components

def predict_with_gmm(gmm_models, features_list):
    """
    Make predictions using GMM models
    
    Args:
        gmm_models: Dictionary of GMM models for each class
        features_list: List of feature matrices
        
    Returns:
        Predictions, log probabilities
    """
    predictions = []
    log_probs_list = []
    
    # Loop through each feature matrix
    for features in features_list:
        # Calculate log-likelihood for each class
        log_probs = np.zeros(31)
        
        for i, class_id in enumerate(range(1, 32)):
            if class_id in gmm_models:
                # Get log-likelihood from GMM
                log_prob = gmm_models[class_id].score(features)
                log_probs[i] = log_prob
            else:
                log_probs[i] = -np.inf
        
        # Make prediction (class with highest log-likelihood)
        pred = np.argmax(log_probs) + 1  # +1 because classes are 1-indexed
        
        predictions.append(pred)
        log_probs_list.append(log_probs)
    
    return np.array(predictions), np.array(log_probs_list)
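
# Usage sketch: score a batch of recordings at once. preds[i] is the 1-indexed
# class whose GMM best explains features_list[i]; log_probs[i] is the full
# 31-dimensional vector of average log-likelihoods for that recording.
#
#     preds, log_probs = predict_with_gmm(gmm_models, features)
#     accuracy = np.mean(preds == labels)  # if ground-truth labels are known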

def process_test_files(gmm_models, test_dir):
    """
    Process test files and generate predictions
    
    Args:
        gmm_models: Dictionary of GMM models for each class
        test_dir: Directory containing test files
        
    Returns:
        results: List of results (filename, prediction, log_probs)
    """
    results = []
    
    # Get all WAV files in test directory (recursive)
    wav_files = glob.glob(os.path.join(test_dir, "**", "*.wav"), recursive=True)
    
    print(f"Found {len(wav_files)} WAV files for testing")
    
    for wav_file in wav_files:
        # Extract features
        feature_matrix = wav2mfcc(wav_file)
        
        if feature_matrix is not None and len(feature_matrix) > 0:
            # Get predictions
            pred, log_probs = predict_with_gmm(gmm_models, [feature_matrix])
            
            # Get filename without extension
            filename = os.path.splitext(os.path.basename(wav_file))[0]
            
            # Store result
            results.append((filename, pred[0], log_probs[0]))
    
    return results

def save_results(results, output_file="audio_gmm_results.txt"):
    """
    Save results to output file
    
    Args:
        results: List of results (filename, prediction, log_probs)
        output_file: Output filename
    """
    with open(output_file, 'w') as f:
        for filename, pred, log_probs in results:
            # Format: filename prediction log_prob_1 log_prob_2 ... log_prob_31
            line = f"{filename} {pred}"
            for lp in log_probs:
                line += f" {lp:.6f}"
            f.write(line + "\n")
    
    print(f"Results saved to {output_file}")

def save_models(gmm_models, n_components, output_file="audio_gmm_models.pkl"):
    """
    Save GMM models to file
    
    Args:
        gmm_models: Dictionary of GMM models for each class
        n_components: Number of GMM components
        output_file: Output filename
    """
    with open(output_file, 'wb') as f:
        pickle.dump((gmm_models, n_components), f)
    
    print(f"Models saved to {output_file}")

def load_models(input_file="audio_gmm_models.pkl"):
    """
    Load GMM models from file
    
    Args:
        input_file: Input filename
        
    Returns:
        gmm_models, n_components
    """
    with open(input_file, 'rb') as f:
        gmm_models, n_components = pickle.load(f)
    
    return gmm_models, n_components

def main():
    """Main function"""
    parser = argparse.ArgumentParser(description="Speaker recognition system")
    parser.add_argument("--data_dir", type=str, default="SUR_projekt2024-2025",
                        help="Base directory containing the data")
    parser.add_argument("--mode", type=str, choices=["train", "test", "train_test"],
                        default="train_test", help="Mode of operation")
    parser.add_argument("--model_file", type=str, default="audio_gmm_models.pkl",
                        help="File to save/load models")
    parser.add_argument("--test_dir", type=str, default=None,
                        help="Directory containing test files")
    parser.add_argument("--output_file", type=str, default="audio_gmm_results.txt",
                        help="Output file for test results")
    parser.add_argument("--n_mfcc", type=int, default=13,
                        help="Number of MFCC coefficients")
    parser.add_argument("--use_dev_for_training", action="store_true",
                        help="Use a portion of dev data for training")
    parser.add_argument("--dev_ratio", type=float, default=0.5,
                        help="Portion of dev data to use for training (0.0-1.0)")
    parser.add_argument("--remaining_dev_file", type=str, default="remaining_dev_audio.txt",
                        help="File to save remaining dev files not used for training")
    
    args = parser.parse_args()
    
    # Set default test directory if not specified
    if args.test_dir is None:
        args.test_dir = os.path.join(args.data_dir, "dev")
    
    if args.mode in ["train", "train_test"]:
        # Load data
        print("Loading training data...")
        features, labels, filenames = load_data(args.data_dir, "train", 
                                              use_dev_for_training=args.use_dev_for_training,
                                              dev_ratio=args.dev_ratio)
        
        # If we're using dev data for training, create a file with remaining dev files for validation
        if args.use_dev_for_training:
            create_remaining_dev_files(args.data_dir, args.dev_ratio, args.remaining_dev_file)
        
        # Train GMM models
        print("Training GMM models...")
        gmm_models, n_components = train_gmm_models(features, labels)
        
        # Save models
        save_models(gmm_models, n_components, args.model_file)
    
    if args.mode in ["test", "train_test"]:
        if args.mode == "test":
            # Load models
            print("Loading GMM models...")
            gmm_models, n_components = load_models(args.model_file)
        
        # Process test files
        print("Processing test files...")
        # If we're using part of dev for training and want to test on the remaining dev files
        if args.use_dev_for_training and os.path.exists(args.remaining_dev_file):
            print(f"Testing on remaining dev files listed in {args.remaining_dev_file}")
            results = process_remaining_dev_files(gmm_models, args.remaining_dev_file, args.data_dir)
        else:
            results = process_test_files(gmm_models, args.test_dir)
        
        # Save results
        save_results(results, args.output_file)

def create_remaining_dev_files(data_dir, dev_ratio, output_file):
    """
    Create a file listing the remaining dev files not used for training
    
    Args:
        data_dir: Base directory containing the data
        dev_ratio: Portion of dev data used for training
        output_file: File to save the list of remaining dev files
    """
    print(f"Creating list of remaining dev files for validation...")
    dev_dir = os.path.join(data_dir, "dev")
    remaining_files = []
    
    for class_id in range(1, 32):
        class_dir = os.path.join(dev_dir, str(class_id))
        if not os.path.exists(class_dir):
            continue
        
        # Get all WAV files in the dev class directory (sorted to match the
        # ordering used in load_data)
        wav_files = sorted(glob.glob(os.path.join(class_dir, "*.wav")))
        
        # Calculate how many files were used for training
        num_train_files = max(1, int(len(wav_files) * dev_ratio))
        
        # Select files not used for training
        test_files = wav_files[num_train_files:]
        
        for test_file in test_files:
            # Store class and file path
            remaining_files.append((class_id, test_file))
    
    # Save the list of remaining files
    with open(output_file, 'w') as f:
        for class_id, file_path in remaining_files:
            f.write(f"{class_id},{file_path}\n")
    
    print(f"Saved list of {len(remaining_files)} remaining dev files to {output_file}")

def process_remaining_dev_files(gmm_models, file_list, data_dir):
    """
    Process remaining dev files for testing
    
    Args:
        gmm_models: Dictionary of GMM models for each class
        file_list: File containing list of remaining dev files
        data_dir: Base directory containing the data
        
    Returns:
        List of (filename, prediction, log_probs) tuples
    """
    results = []
    
    # Read the list of remaining dev files
    with open(file_list, 'r') as f:
        remaining_files = [line.strip().split(',') for line in f]
    
    print(f"Found {len(remaining_files)} files for testing")
    
    # Process each file (the stored class_id is the ground-truth label and is
    # not used for prediction)
    for class_id_str, file_path in remaining_files:
        # Extract features
        features = wav2mfcc(file_path)
        
        if features is not None and len(features) > 0:
            # Score against all class GMMs (reuses predict_with_gmm)
            preds, log_probs = predict_with_gmm(gmm_models, [features])
            
            # Filename without extension
            filename = os.path.splitext(os.path.basename(file_path))[0]
            
            # Store results
            results.append((filename, preds[0], log_probs[0]))
    
    return results

def combine_results(audio_file, image_file, output_file="fusion_results.txt", weights=(0.5, 0.5)):
    """
    Combine results from audio and image modalities using weighted fusion
    
    Args:
        audio_file: Path to audio results file
        image_file: Path to image results file
        output_file: Path to output fusion results file
        weights: Tuple of weights for (audio, image) modalities
    """
    # Normalize weights
    total = sum(weights)
    audio_weight, image_weight = weights[0]/total, weights[1]/total
    
    print(f"Combining results with weights: Audio={audio_weight:.2f}, Image={image_weight:.2f}")
    
    # Read audio results
    audio_results = {}
    with open(audio_file, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 33:  # Filename + class + 31 scores
                filename = parts[0]
                log_probs = [float(p) for p in parts[2:33]]
                audio_results[filename] = log_probs
    
    # Read image results
    image_results = {}
    with open(image_file, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 33:  # Filename + class + 31 scores
                filename = parts[0]
                log_probs = [float(p) for p in parts[2:33]]
                image_results[filename] = log_probs
    
    # Find common filenames (base filename without extensions; rsplit returns
    # the whole string when there is no '.')
    audio_basenames = set(k.rsplit('.', 1)[0] for k in audio_results)
    image_basenames = set(k.rsplit('.', 1)[0] for k in image_results)
    common_basenames = audio_basenames.intersection(image_basenames)
    
    print(f"Found {len(common_basenames)} common files between audio and image results")
    
    # Combine results
    fusion_results = []
    for basename in common_basenames:
        # Find the corresponding keys
        audio_key = next((k for k in audio_results if k.rsplit('.', 1)[0] == basename), None)
        image_key = next((k for k in image_results if k.rsplit('.', 1)[0] == basename), None)
        
        if audio_key and image_key:
            # Get log probabilities
            audio_log_probs = audio_results[audio_key]
            image_log_probs = image_results[image_key]
            
            # Convert log-likelihoods to per-class posteriors with a softmax;
            # subtracting the max before exponentiating avoids underflow when
            # the log-likelihoods are large negative numbers (the normalized
            # result is mathematically unchanged)
            audio_log_probs = np.asarray(audio_log_probs)
            image_log_probs = np.asarray(image_log_probs)
            audio_probs = np.exp(audio_log_probs - np.max(audio_log_probs))
            image_probs = np.exp(image_log_probs - np.max(image_log_probs))
            
            # Normalize probabilities (the sum is at least 1 after the shift)
            audio_probs = audio_probs / np.sum(audio_probs)
            image_probs = image_probs / np.sum(image_probs)
            
            # Weighted combination
            combined_probs = audio_weight * audio_probs + image_weight * image_probs
            
            # Convert back to log probabilities
            combined_log_probs = np.log(combined_probs + 1e-10)  # Add small epsilon to avoid log(0)
            
            # Make prediction
            pred = np.argmax(combined_log_probs) + 1  # +1 because classes are 1-indexed
            
            # Store result
            fusion_results.append((basename, pred, combined_log_probs))
    
    # Save fusion results
    with open(output_file, 'w') as f:
        for filename, pred, log_probs in fusion_results:
            # Format: filename prediction log_prob_1 log_prob_2 ... log_prob_31
            line = f"{filename} {pred}"
            for lp in log_probs:
                line += f" {lp:.6f}"
            f.write(line + "\n")
    
    print(f"Fusion results saved to {output_file}")
    return fusion_results
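
# Worked fusion example (made-up numbers, two classes for brevity): with audio
# log-likelihoods [-40, -42] and image log-likelihoods [-10, -9], the
# per-modality softmax gives audio ≈ [0.881, 0.119] and image ≈ [0.269, 0.731];
# equal weights (0.5, 0.5) combine these to [0.575, 0.425], so class 1 wins
# even though the image modality alone prefers class 2.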

if __name__ == "__main__":
    main()