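/* MEX-C implementation of AdaBoost feature selection.
 *
 * Called from MATLAB through mexFunction(): boosts single-threshold weak
 * classifiers over filtered positive/negative maps and returns the selected
 * features together with their thresholds, polarities and weights.
 */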

#include <stdio.h>
#include <stdlib.h>
#include <string.h>     /* memcpy */
#include "mex.h"        /* the algorithm is connected to matlab */
#include "math.h"
#include "matrix.h"
/* Small numeric helpers.
 * NOTE: the function-like macros evaluate their arguments more than once,
 * so never pass expressions with side effects (e.g. MAX(i++, j)). */
#define PI 3.1415926                        /* pi to 7 decimal places */
#define ABS(x) (0 < (x) ? (x) : -(x))
#define MAX(x, y) ((y) < (x) ? (x) : (y))
#define MIN(x, y) ((y) > (x) ? (x) : (y))

/* ---- File-scope state shared by mexFunction, adaboost and
 *      getBestThresholdAndPolar (zero-initialized at load) ---- */

/* Training data: one pointer per (orientation, example) pair, indexed as
 * maps[o*nExamples + i]; each points at an sx-by-sy single-precision map. */
float** posMaps;
float** negMaps;

/* Problem sizes. */
int nPos;             /* number of positive examples */
int nNeg;             /* number of negative examples */
int nOri;             /* number of orientations */
int sx;               /* rows of each map (mxGetM) */
int sy;               /* columns of each map (mxGetN) */

/* Search parameters read from MATLAB. */
int numGridPoint;     /* number of grid points to search for the threshold of the weak classifier */
int halfFilterSize;   /* read from input 4; not used by the search in this file */
int subSampleStep;    /* read from input 5; not used by the search in this file */

/* Per-candidate-feature scratch buffers (one entry per data dimension). */
float* thresholds;
bool*  polars;
float* errRates;

/* Boosting weights over examples: positives first, then negatives. */
float* dataWeights;

/* Strong-classifier output, one entry per selected feature. */
int    maxNumFeature;
int*   selectedXs;
int*   selectedYs;
int*   selectedOs;
float* selectedThresholds;
bool*  selectedPolars;
float* lambdas;

/* Widen [*lo, *hi] so that it also covers maps[base+i][pix] for 0 <= i < n. */
static void expandRange( float** maps, int base, int n, int pix, float* lo, float* hi )
{
    int i;
    float v;

    for( i = 0; i < n; ++i )
    {
        v = maps[base + i][pix];
        if( v > *hi )
        {
            *hi = v;
        }
        if( v < *lo )
        {
            *lo = v;
        }
    }
}

/* Train the best weak classifier for one data dimension.
 *
 * The dimension is pixel (x, y) of orientation o (all 0-based). Candidate
 * thresholds are numGridPoint evenly spaced values spanning [min, max] of
 * this dimension over all positive and negative examples. For each candidate
 * t the weighted error of the rule "val > t => positive" is computed with
 * dataWeights (positives first, then negatives). If that error is >= 0.5 the
 * rule is inverted (polarity 0: "val <= t => positive") and its error
 * becomes 1 - error.
 *
 * Outputs:
 *   threshold  - best candidate threshold (first one found on ties)
 *   polar      - 1 for "val > threshold => positive", 0 for the inverted rule
 *   minErrRate - weighted error of that classifier (<= 0.5 after folding)
 *
 * Aborts via mexErrMsgTxt if an output pointer is NULL, if there are no
 * examples at all, or if the weights look unnormalized (error > 1.01).
 * With fewer than two grid points, the single threshold min is tried.
 */
void getBestThresholdAndPolar( /* input: (x,y,o) starts from 0 */ int x, int y, int o,
                        /* output: */ float* threshold, bool *polar, float *minErrRate )
{
    const int pix = y*sx + x;   /* offset of pixel (x, y) in an sx-by-sy map */
    int i, k, nGrid, tmpPolar;
    float stepSize, candidateThres, errRate;
    float minVal = 0, maxVal;

    if( threshold == NULL )
    {
        mexErrMsgTxt("warning !! threshold == NULL");
    }
    if( polar == NULL || minErrRate == NULL )
    {
        mexErrMsgTxt("warning !! polar or minErrRate == NULL");
    }

    /* value range of this dimension over all examples; seed from the first
       available example so an empty positive set is never indexed */
    if( nPos > 0 )
    {
        minVal = posMaps[o*nPos][pix];
    }
    else if( nNeg > 0 )
    {
        minVal = negMaps[o*nNeg][pix];
    }
    else
    {
        mexErrMsgTxt("warning !! no training examples");
    }
    maxVal = minVal;
    expandRange( posMaps, o*nPos, nPos, pix, &minVal, &maxVal );
    expandRange( negMaps, o*nNeg, nNeg, pix, &minVal, &maxVal );

    /* A single grid point would divide by zero (and a non-positive count
       would leave the outputs unset); fall back to the one threshold minVal. */
    nGrid = numGridPoint > 1 ? numGridPoint : 1;
    stepSize = nGrid > 1 ? ( maxVal - minVal ) / (nGrid - 1) : 0.0f;

    *minErrRate = 1;
    *polar = 1;
    *threshold = minVal;

    /* search over the grid */
    for( k = 0; k < nGrid; ++k )
    {
        candidateThres = minVal + k * stepSize;

        /* error of the rule "val > candidateThres => positive":
           positives at or below the threshold, negatives above it */
        errRate = 0;
        for( i = 0; i < nPos; ++i )
        {
            if( posMaps[o*nPos+i][pix] <= candidateThres )
            {
                errRate += dataWeights[i];
            }
        }
        for( i = 0; i < nNeg; ++i )
        {
            if( negMaps[o*nNeg+i][pix] > candidateThres )
            {
                errRate += dataWeights[nPos+i];
            }
        }

        /* normalized weights cannot produce an error above 1 */
        if( errRate > 1.01 )
        {
            mexErrMsgTxt("warning !! errRate > 1 ");
        }

        /* a rule no better than chance is inverted (polarity 0) */
        if( errRate >= 0.5 )
        {
            errRate = 1 - errRate;
            tmpPolar = 0;
        }
        else
        {
            tmpPolar = 1;
        }

        if( errRate < *minErrRate )
        {
            *minErrRate = errRate;
            if( *minErrRate < -.01 )
            {
                mexErrMsgTxt("warning !! *minErrRate < 0");
            }
            *polar = tmpPolar;
            *threshold = candidateThres;
        }
    }
}

/* adaboost core function
 *
 * Runs maxNumFeature boosting rounds over the file-scope training data.
 * Each round:
 *   1. trains the best threshold classifier for every candidate feature
 *      (x, y, o) with 1 <= x < sx, 1 <= y < sy, 0 <= o < nOri, storing the
 *      results in thresholds/polars/errRates, and keeps the feature with the
 *      lowest weighted error (the first one in scan order on ties);
 *   2. appends it to selectedXs/Ys/Os/Thresholds/Polars with weight
 *      lambda = -log(beta + 1e-6), where beta = err / (1 - err);
 *   3. multiplies the weights of correctly classified examples by beta
 *      (floored at betaFloor) and renormalizes dataWeights to sum to 1.
 *
 * thresholds, polars and errRates must hold at least (sx-1)*(sy-1)*nOri
 * entries. Aborts via mexErrMsgTxt if no feature can be selected (no
 * candidates, or every error is NaN) or if the weights sum to zero, rather
 * than indexing the buffers with an invalid feature index.
 */
void adaboost()
{
    /* Smallest factor used in the weight update. A perfect weak classifier
     * (err == 0) gives beta == 0, which would zero every weight and make the
     * renormalization divide by zero; the same 1e-6 keeps lambda finite. */
    const float betaFloor = 1e-6f;
    int i, j, currentNumFeature;
    float minErrRate, beta, updateBeta, normConst, val, bestThres;
    int x, y, o, bestX, bestY, bestO, bestFeatureInd;
    bool bestPolar;

    for( currentNumFeature = 0; currentNumFeature < maxNumFeature; ++currentNumFeature )
    {
        /* Train a weak classifier for each candidate feature under the current
           data weights and keep the best one. Folded error rates are <= 0.5,
           so the 0.51 start is beaten by the first feature with a finite error. */
        minErrRate = 0.51;
        bestFeatureInd = -1;
        bestX = bestY = bestO = 0;
        j = 0;
        for( x = 1; x < sx; ++x )
        {
            for( y = 1; y < sy; ++y )
            {
                for( o = 0; o < nOri; ++o, ++j )
                {
                    getBestThresholdAndPolar( x, y, o, thresholds+j, polars+j, errRates+j );
                    if( minErrRate > errRates[j] )
                    {
                        minErrRate = errRates[j];
                        bestFeatureInd = j;
                        bestX = x;
                        bestY = y;
                        bestO = o;
                    }
                }
            }
        }
        if( bestFeatureInd < 0 )
        {
            mexErrMsgTxt("warning !! no weak classifier could be selected");
        }
        bestThres = thresholds[bestFeatureInd];
        bestPolar = polars[bestFeatureInd];

        /* add this weak classifier into the strong classifier */
        selectedXs[currentNumFeature] = bestX;
        selectedYs[currentNumFeature] = bestY;
        selectedOs[currentNumFeature] = bestO;
        selectedThresholds[currentNumFeature] = bestThres;
        selectedPolars[currentNumFeature] = bestPolar;

        /* model parameter for the newly selected feature */
        beta = minErrRate / ( 1 - minErrRate );
        lambdas[currentNumFeature] = - log( beta + 1e-6 );

        /* down-weight correctly classified examples, then renormalize */
        updateBeta = beta < betaFloor ? betaFloor : beta;
        normConst = 0;
        for( i = 0; i < nPos; ++i )
        {
            val = posMaps[bestO*nPos+i][bestY*sx+bestX];
            if( bestPolar == ( val > bestThres ) ) /* weak classifier correctly predicts */
            {
                dataWeights[i] *= updateBeta;
            }
            normConst += dataWeights[i];
        }
        for( i = 0; i < nNeg; ++i )
        {
            val = negMaps[bestO*nNeg+i][bestY*sx+bestX];
            if( bestPolar == ( val <= bestThres ) ) /* weak classifier correctly predicts */
            {
                dataWeights[nPos+i] *= updateBeta;
            }
            normConst += dataWeights[nPos+i];
        }
        if( !( normConst > 0 ) )   /* also catches NaN */
        {
            mexErrMsgTxt("warning !! data weights sum to zero");
        }
        for( i = 0; i < nPos + nNeg; ++i )
        {
            dataWeights[i] /= normConst;
        }

        mexPrintf("%d : select %d-th atom, x=%d, y=%d, o=%d, thres=%.3f, polar=%d, errRate=%.3f, lambda=%.3f, beta=%.3f\n",
            currentNumFeature,bestFeatureInd,bestX,bestY,bestO,bestThres,
            (int)bestPolar,minErrRate,lambdas[currentNumFeature],beta);
        mexEvalString("drawnow");
    }
}



/* mex function is used to pass on the pointers and scalars from matlab, 
   so that heavy computation can be done by C, which puts the results into 
   some of the pointers. After that, matlab can then use these results. 
   
   So matlab is very much like a managing platform for organizing the 
   experiments, and mex C is like a work enginee for fast computation. */

void mexFunction(int nlhs, mxArray *plhs[], 
                 int nrhs, const mxArray *prhs[])                
{
    int ind, i, x, y, dataDim, bytes_to_copy;
    const mxArray *f;
    mwSize ndim;
    const mwSize* dims;
    mwSize dimsOutput[2];
    void* start_of_pr;
    mxClassID datatype;
 
    /*
	 * input variable 0: posMaps 
	 */
    ndim = mxGetNumberOfDimensions(prhs[0]);
    dims = mxGetDimensions(prhs[0]);
    nPos = dims[0];
    nOri = dims[1]; /* number of orientations */
 
    f = mxGetCell(prhs[0], 0); /* get the first cell element */
    sx = mxGetM(f);
    sy = mxGetN(f);
    
    posMaps = mxCalloc(nPos*nOri, sizeof(*posMaps));   /* fitered positive images */
    for (i=0; i<nPos; ++i)
    {
        for (ind=0; ind<nOri; ++ind)
        {
            f = mxGetCell(prhs[0], ind*nPos+i);
            datatype = mxGetClassID(f);
            if (datatype != mxSINGLE_CLASS)
                mexErrMsgTxt("warning !! float precision required.");
            posMaps[ind*nPos+i] = mxGetPr(f);    /* get pointers to filtered images */   
            /* warning: assignment of pointer to float to pointer to float  */
        }
    }

    /* 
     * input variable 1: negMaps 
     */
    ndim = mxGetNumberOfDimensions(prhs[1]);
    dims = mxGetDimensions(prhs[1]);
    nNeg = dims[0];
    
    negMaps = mxCalloc(nNeg*nOri, sizeof(*negMaps));   /* fitered positive images */
    for (i=0; i<nNeg; ++i)
    {
        for (ind=0; ind<nOri; ++ind)
        {
            f = mxGetCell(prhs[1], ind*nNeg+i);
            datatype = mxGetClassID(f);
            if (datatype != mxSINGLE_CLASS)
                mexErrMsgTxt("warning !! float precision required.");
            negMaps[ind*nNeg+i] = mxGetPr(f);    /* get pointers to filtered images */       
        }
    }
    
    /*
     * input variable 2: data weights (will be changed)
     */
    dataWeights = mxGetPr(prhs[2]);
    
    /*
     * input variable 3: number of grid points 
     */
    numGridPoint = mxGetScalar(prhs[3]);
    
    /* 
     * input variable 4: half filter size 
     */
    halfFilterSize = mxGetScalar(prhs[4]);
    
    /* 
     * input variable 5: subsample step 
     */
    subSampleStep = mxGetScalar(prhs[5]);
    
    /*
     * input variable 6: maximum number of selected features
     */
    maxNumFeature = mxGetScalar(prhs[6]);
    
    /*  count the number of data dimensions */
    dataDim = 0;
    for( x = 1; x < sx; x += 1)
         for( y = 1; y < sy; y += 1)
            ++dataDim;
    dataDim *= nOri;
    
    mexPrintf("input successful. dataDim = %d\n",dataDim);
    
    /*
     * Adaboost algorithm.
     */
    thresholds = mxCalloc(dataDim,sizeof(*thresholds));
    polars = mxCalloc(dataDim,sizeof(*polars));
    lambdas =  mxCalloc(maxNumFeature,sizeof(*lambdas));
    selectedXs = mxCalloc(maxNumFeature,sizeof(*selectedXs));
    selectedYs = mxCalloc(maxNumFeature,sizeof(*selectedYs));
    selectedOs = mxCalloc(maxNumFeature,sizeof(*selectedOs));
    selectedThresholds = mxCalloc(maxNumFeature,sizeof(*selectedThresholds));
    selectedPolars = mxCalloc(maxNumFeature,sizeof(*selectedPolars));
    errRates =  mxCalloc(dataDim,sizeof(*errRates));

    adaboost();
    
    /* =============================================
     * Handle output variables.
     * ============================================= 
     */
    
    dimsOutput[0] = 1; dimsOutput[1] = maxNumFeature;
    
    /*
     * output variable 0: selectedXs
     */
	plhs[0] = mxCreateNumericArray(2, dimsOutput ,mxINT32_CLASS, mxREAL);
    /* populate the real part of the created array */
    start_of_pr = (int*)mxGetData(plhs[0]);
    bytes_to_copy = maxNumFeature * mxGetElementSize(plhs[0]);
    memcpy(start_of_pr,selectedXs,bytes_to_copy);

    /*
     * output variable 1: selectedYs
     */
    plhs[1] = mxCreateNumericArray(2, dimsOutput ,mxINT32_CLASS, mxREAL);
    /* populate the real part of the created array */
    start_of_pr = (int*)mxGetData(plhs[1]);
    bytes_to_copy = maxNumFeature * mxGetElementSize(plhs[1]);
    memcpy(start_of_pr,selectedYs,bytes_to_copy);
    
    /*
     * output variable 2: selectedOs
     */
    plhs[2] = mxCreateNumericArray(2, dimsOutput ,mxINT32_CLASS, mxREAL);
    /* populate the real part of the created array */
    start_of_pr = (int*)mxGetData(plhs[2]);
    bytes_to_copy = maxNumFeature * mxGetElementSize(plhs[2]);
    memcpy(start_of_pr,selectedOs,bytes_to_copy);
    
    /*
     * output variable 3: selectedThresholds
     */
    plhs[3] = mxCreateNumericArray(2, dimsOutput ,mxSINGLE_CLASS, mxREAL);
    /* populate the real part of the created array */
    start_of_pr = (float*)mxGetData(plhs[3]);
    bytes_to_copy = maxNumFeature * mxGetElementSize(plhs[3]);
    memcpy(start_of_pr,selectedThresholds,bytes_to_copy);
    
    /*
     * output variable 4: selectedPolars
     */
    plhs[4] = mxCreateNumericArray(2, dimsOutput ,mxLOGICAL_CLASS, mxREAL);
    /* populate the real part of the created array */
    start_of_pr = (bool*)mxGetData(plhs[4]);
    bytes_to_copy = maxNumFeature * mxGetElementSize(plhs[4]);
    memcpy(start_of_pr,selectedPolars,bytes_to_copy);
    
    /*
     * output variable 5: lambdas
     */
    plhs[5] = mxCreateNumericArray(2, dimsOutput ,mxSINGLE_CLASS, mxREAL);
    /* populate the real part of the created array */
    start_of_pr = (float*)mxGetData(plhs[5]);
    bytes_to_copy = maxNumFeature * mxGetElementSize(plhs[5]);
    memcpy(start_of_pr,lambdas,bytes_to_copy);
}

