
package cz.vutbr.fit.rodos;

import cz.vutbr.fit.rodos.db.DatabaseManager;
import cz.vutbr.fit.rodos.db.rows.AsimRow;
import cz.vutbr.fit.rodos.db.rows.GenericRow;
import cz.vutbr.fit.rodos.db.rows.TableRow;
import cz.vutbr.fit.rodos.db.rows.TollGateRow;
import cz.vutbr.fit.rodos.db.tables.AsimTable;
import cz.vutbr.fit.rodos.db.tables.GenericTable;
import cz.vutbr.fit.rodos.db.tables.TollGateTable;
import general.io.OutputManager;
import libsvm.SvmModel;
import libsvm.SvmNode;
import libsvm.SvmParameters;
import libsvm.SvmProblem;
import libsvm.svm.Svm;

import java.io.IOException;
import java.util.*;

/**
 * Command-line entry point. Depending on {@link MainParameters#mode} it either trains an SVM model
 * on hourly sums read from the database and saves it to {@code modelFile}, or loads a saved model
 * and prints the actual vs. predicted value for each hour.
 *
 * <p>For every hour the value of {@code predictDetector} is the target ({@code y}), and the values
 * of the remaining detectors in {@code allDetectors} are the features ({@code x}).
 */
public class Main
{
    private final DatabaseManager dm;
    
    /**
     * Creates the application with a newly constructed {@link DatabaseManager}.
     *
     * @throws Exception if the database manager cannot be created
     */
    public Main() throws Exception
    {
        dm = new DatabaseManager();
    }
    
    /**
     * Program entry point. Exceptions escaping {@link #run(String[])} are reported through
     * {@link OutputManager}, which is also installed as the default uncaught-exception handler.
     *
     * @param args command-line arguments, parsed by {@link MainParameters}
     */
    public static void main( String[] args )
    {
        Thread.setDefaultUncaughtExceptionHandler( OutputManager.getInstance() );
    
        try
        {
            new Main().run( args );
        }
        catch( Exception e )
        {
            OutputManager.getInstance().printException( e );
        }
    }
    
    /**
     * Parses the arguments, configures the database connection and dispatches to training or
     * prediction.
     */
    private void run( String[] args ) throws Exception
    {
        MainParameters mainParams = new MainParameters( args );
        
        configureDatabase( mainParams );
        
        if( mainParams.mode == MainParameters.MODE_TRAIN )
        {
            train( mainParams );
        }
        else if( mainParams.mode == MainParameters.MODE_PREDICT )
        {
            predict( mainParams );
        }
        else
        {
            OutputManager.getInstance().printErrorMessage( "Nothing to do." );
        }
    }
    
    /** Applies every database connection setting that was given on the command line. */
    private void configureDatabase( MainParameters mainParams )
    {
        if( mainParams.dbHost != null )
        {
            dm.setHost( mainParams.dbHost );
        }
        
        if( mainParams.dbPort != MainParameters.NO_VALUE )
        {
            dm.setPort( mainParams.dbPort );
        }
    
        if( mainParams.db != null )
        {
            dm.setDatabase( mainParams.db );
        }
        
        if( mainParams.dbUser != null )
        {
            dm.setUser( mainParams.dbUser );
        }
    
        if( mainParams.dbPassword != null )
        {
            dm.setPassword( mainParams.dbPassword );
        }
    }
    
    /** Builds the SVM training parameters from the command-line options. */
    private static SvmParameters createSvmParameters( MainParameters mainParams )
    {
        SvmParameters svmParams = new SvmParameters();
    
        svmParams.svmType = mainParams.svmType;
        svmParams.kernelType = mainParams.kernelType;
        // Intended as 1 / number of features.
        // NOTE(review): 5 and 11 are hard-coded per detector type, while createProblem uses
        // (number of detectors - 1) features; confirm these agree with the detector sets in use.
        svmParams.gamma = 1.0 / ( mainParams.detectorType == MainParameters.DETECTOR_TOLLGATE ? 5 : 11 );
        svmParams.degree = mainParams.polyDegree; // degree of the polynomial kernel
        
        return svmParams;
    }
    
    /**
     * Trains a model on the configured time range and saves it to {@code modelFile}.
     * Reports an error and does nothing when the time range yields no data.
     */
    private void train( MainParameters mainParams ) throws Exception
    {
        SvmProblem trainingProblem = this.createProblem( mainParams );
    
        if( trainingProblem.l == 0 )
        {
            OutputManager.getInstance().printErrorMessage( "No data found for training; model not trained." );
            return;
        }
    
        SvmModel model = Svm.train( trainingProblem, createSvmParameters( mainParams ) );
    
        Svm.saveModel( mainParams.modelFile, model );
        
        OutputManager.getInstance().printMessage( "Model saved to: " + mainParams.modelFile );
    }
    
    /**
     * Loads the model from {@code modelFile} and prints each hour's actual and predicted value.
     * The output is either a Mathematica list of {@code {y, y_p}} pairs or one message per hour.
     *
     * @throws IOException if the model file cannot be read
     */
    private void predict( MainParameters mainParams ) throws Exception
    {
        SvmModel model;
        
        try
        {
            model = Svm.loadModel( mainParams.modelFile );
        }
        catch( IOException e )
        {
            throw new IOException( "Cannot read model from file: " + mainParams.modelFile, e );
        }
    
        SvmProblem testingProblem = this.createProblem( mainParams );
        OutputManager output = OutputManager.getInstance();
    
        if( output.isMathematicaOnly() )
        {
            StringBuilder out = new StringBuilder( "{" );
        
            for( int i = 0; i < testingProblem.l; ++i )
            {
                double y = testingProblem.y[i];
                double yPredicted = Svm.predict( model, testingProblem.x[i] );
            
                if( i > 0 )
                {
                    out.append( ',' );
                }
            
                out.append( '{' )
                   .append( toMathematicaNumber( y ) )
                   .append( ',' )
                   .append( toMathematicaNumber( yPredicted ) )
                   .append( '}' );
            }
        
            out.append( '}' );
        
            output.printOutputForMathematica( out.toString() );
        }
        else
        {
            for( int i = 0; i < testingProblem.l; ++i )
            {
                double y = testingProblem.y[i];
                double yPredicted = Svm.predict( model, testingProblem.x[i] );
            
                output.printMessage( "y = " + y + ", y_p = " + yPredicted );
            }
        }
    }
    
    /**
     * Formats a double as a Mathematica number literal. {@link Double#toString(double)} uses
     * scientific notation such as {@code 1.0E-4} for magnitudes below 1e-3 or at least 1e7.
     * Mathematica parses that as {@code 1.0*E - 4}, with E being Euler's number. Mathematica's
     * exponent marker is {@code *^}.
     */
    private static String toMathematicaNumber( double value )
    {
        return Double.toString( value ).replace( "E", "*^" );
    }
    
    /**
     * Loads hourly sums for all detectors and turns them into an SVM problem.
     *
     * <p>Each distinct timestamp becomes one sample, ordered by ascending time. The features are
     * the values of every detector in {@code allDetectors} except {@code predictDetector}, in that
     * order, with feature indices starting at 0. The label is the value of {@code predictDetector}.
     * A detector with no row for a given timestamp contributes 0.
     *
     * @return the problem; never {@code null}, but {@code l} may be 0 when no rows were found
     */
    private SvmProblem createProblem( MainParameters params )
    {
        List<Integer> gids = orderedDetectorIds( params );
        Map<Integer, List<TableRow>> dbData = loadHourlySums( params, gids );
        
        // time -> (detector id -> value); TreeMap keeps the timestamps sorted ascending.
        SortedMap<Long, Map<Integer, Double>> values = new TreeMap<>();
        
        for( int gid : gids )
        {
            for( TableRow row : dbData.get( gid ) )
            {
                values.computeIfAbsent( rowTime( params, row ), t -> new HashMap<>() )
                      .put( gid, rowValue( params, row ) );
            }
        }
        
        return buildProblem( gids, values );
    }
    
    /** Returns the feature detectors in input order, followed by the predicted detector last. */
    private static List<Integer> orderedDetectorIds( MainParameters params )
    {
        List<Integer> gids = new ArrayList<>();
        
        for( int gid : params.allDetectors )
        {
            if( gid != params.predictDetector )
            {
                gids.add( gid );
            }
        }
        
        gids.add( params.predictDetector );
        
        return gids;
    }
    
    /** Queries the table matching {@code detectorType} for each detector's hourly sums. */
    private Map<Integer, List<TableRow>> loadHourlySums( MainParameters params, List<Integer> gids )
    {
        Map<Integer, List<TableRow>> dbData = new HashMap<>();
        
        if( params.detectorType == MainParameters.DETECTOR_TOLLGATE )
        {
            TollGateTable tollGateTable = new TollGateTable( dm );
        
            for( int gid : gids )
            {
                List<TollGateRow> rows = tollGateTable.getSumByHour( gid, params.timeFrom, params.timeTo );
            
                dbData.put( gid, new ArrayList<>( rows ) );
            }
        }
        else if( params.detectorType == MainParameters.DETECTOR_ASIM )
        {
            AsimTable asimTable = new AsimTable( dm );
        
            for( int gid : gids )
            {
                List<AsimRow> rows = asimTable.getSumByHour( gid, params.timeFrom, params.timeTo );
            
                dbData.put( gid, new ArrayList<>( rows ) );
            }
        }
        else
        {
            GenericTable table = new GenericTable( dm, params.dbTable, params.dbColDetector, params.dbColTime, params.dbColAmount );
    
            for( int gid : gids )
            {
                List<GenericRow> rows = table.getSumByHour( gid, params.timeFrom, params.timeTo );
        
                dbData.put( gid, new ArrayList<>( rows ) );
            }
        }
        
        return dbData;
    }
    
    /** Returns the row's timestamp in milliseconds, read from the field for its detector type. */
    private static long rowTime( MainParameters params, TableRow row )
    {
        if( params.detectorType == MainParameters.DETECTOR_TOLLGATE )
        {
            return ( (TollGateRow) row ).start.getTime();
        }
        
        if( params.detectorType == MainParameters.DETECTOR_ASIM )
        {
            return ( (AsimRow) row ).start.getTime();
        }
        
        return ( (GenericRow) row ).time.getTime();
    }
    
    /**
     * Returns the row's value for the selected {@code columnType}. Generic rows always use their
     * {@code amount}. For toll-gate and ASIM rows, 0 is returned when {@code columnType} does not
     * name one of that detector type's columns.
     */
    private static double rowValue( MainParameters params, TableRow row )
    {
        if( params.detectorType == MainParameters.DETECTOR_TOLLGATE )
        {
            TollGateRow tgRow = (TollGateRow) row;
    
            if( params.columnType == MainParameters.COLUMN_AXES_2 )
            {
                return tgRow.axes2;
            }
            else if( params.columnType == MainParameters.COLUMN_AXES_3 )
            {
                return tgRow.axes3;
            }
            else if( params.columnType == MainParameters.COLUMN_AXES_4 )
            {
                return tgRow.axes4;
            }
            
            return 0;
        }
        
        if( params.detectorType == MainParameters.DETECTOR_ASIM )
        {
            AsimRow aRow = (AsimRow) row;
    
            if( params.columnType == MainParameters.COLUMN_CATEGORY_1 )
            {
                return aRow.category_1_amount;
            }
            else if( params.columnType == MainParameters.COLUMN_CATEGORY_2 )
            {
                return aRow.category_2_amount;
            }
            else if( params.columnType == MainParameters.COLUMN_CATEGORY_3 )
            {
                return aRow.category_3_amount;
            }
            else if( params.columnType == MainParameters.COLUMN_CATEGORY_4 )
            {
                return aRow.category_4_amount;
            }
            else if( params.columnType == MainParameters.COLUMN_CATEGORY_5 )
            {
                return aRow.category_5_amount;
            }
            else if( params.columnType == MainParameters.COLUMN_CATEGORY_6 )
            {
                return aRow.category_6_amount;
            }
            
            return 0;
        }
        
        return ( (GenericRow) row ).amount;
    }
    
    /**
     * Converts per-time detector values into an {@link SvmProblem}. The last entry of
     * {@code gids} is the label; the others become features indexed by their position in
     * {@code gids}.
     */
    private static SvmProblem buildProblem( List<Integer> gids, SortedMap<Long, Map<Integer, Double>> values )
    {
        int featureCount = gids.size() - 1;
        
        SvmProblem problem = new SvmProblem();
    
        problem.l = values.size();
        problem.x = new SvmNode[problem.l][];
        problem.y = new double[problem.l];
    
        int i = 0;
        
        for( Map<Integer, Double> row : values.values() )
        {
            problem.x[i] = new SvmNode[featureCount];
            
            for( int j = 0; j < gids.size(); ++j )
            {
                // A detector without a row for this hour contributes 0.
                double value = row.getOrDefault( gids.get( j ), 0.0 );
    
                if( j == featureCount ) // last one is the predicted detector
                {
                    problem.y[i] = value;
                }
                else
                {
                    SvmNode node = new SvmNode();
    
                    // Index kept 0-based to stay compatible with previously saved models.
                    node.index = j;
                    node.value = value;
                    
                    problem.x[i][j] = node;
                }
            }
            
            ++i;
        }
        
        return problem;
    }
}
