% LEARNING ACTIVE BASIS MODEL FROM NONALIGNED IMAGES
%% Load in exponential model, mex C codes, and set parameters
% Start from an empty workspace with no open figures, then load the
% pre-computed exponential model.
% ExponentialModel; % uncomment to run ExponentialModel when 'storedExponentialModel' does not exist yet
clear;
close all;
load('storedExponentialModel');   % exponential model stored by a previous run

% Compile the C helpers called below (function-call form of the mex command).
mex('CsharedSketch.c');   % shared sketch: learning from aligned images
mex('Csigmoid.c');        % sigmoid transformation
mex('CgetMAX1.c');        % compute MAX1 maps
mex('CgetSUM2.c');        % detect the object over multi-resolution images
mex('CdrawTemplate.c');   % draw the translated and deformed template
mex('Ccopy.c');           % copy a patch around the detected location
mex('Cflip.c');           % copy with left-right flipping
% --- Active basis / sketch parameters ---
epsilon = .1;                 % allowed correlation between selected Gabors
Correlation = CorrFilter(allFilter, epsilon);   % correlation between filters
locationShiftLimit = 4;       % shift in normal direction = locationShiftLimit*subsample pixels
orientShiftLimit = 1;         % allowed shift in orientation
numElement = 60;              % number of Gabors in the active basis

% --- Learning schedule ---
numIteration = 10;            % number of detect / re-learn iterations
useWholeImageOrNot = 1;       % whether to use the whole image of I{start} or only a portion of it to initialize learning

% --- Template variants searched during detection ---
rotateShiftLimit = 2;                 % rotations tried: -rotateShiftLimit..rotateShiftLimit (steps of pi/numOrient)
numRotate = 2*rotateShiftLimit + 1;   % number of rotations tried per flip state
flipOrNot = 0;                        % left-right flip of the template, 0 or 1
outputFolder = 'catFlipRRLS';
% Every output location is formed by appending a suffix to outputFolder
% (e.g. 'catFlipRRLSeps', 'catFlipRRLS.html'). Each output subfolder is
% emptied if it already exists and created otherwise; a stale HTML page is
% removed.
% exist is called with an explicit type ('dir' / 'file'): the bare form
% exist(name) also reports workspace variables and files / .m files on the
% MATLAB path, which could be mistaken for an existing output folder.
outputSubfolderSuffixes = {'eps', 'png', 'Activate', 'working'};
for iOutputSuffix = 1:numel(outputSubfolderSuffixes)
    outputSubfolder = [outputFolder outputSubfolderSuffixes{iOutputSuffix}];
    if exist(outputSubfolder, 'dir') == 7
        delete([outputSubfolder '/*.*']);
    else
        mkdir(outputSubfolder);
    end
end
if exist([outputFolder '.html'], 'file') == 2
    delete([outputFolder '.html']);
end
clear outputSubfolderSuffixes iOutputSuffix outputSubfolder; % loop temporaries only
rangePercent = .2; % maximum percentage of image size shifted from center
%% Load in training images and initialize SUM1 maps for learning
% Every *.jpg file in imageFolder is read, converted to grayscale when it has
% 3 channels, cast to single and resized with nearest-neighbour interpolation.
% Here resizeFactor is a target size [rows cols] rather than a scale factor,
% so all images end up the same size. The size of I{1} becomes the
% template size.
resizeFactor = [120 120]; % resize the input images to [rows cols]
imageFolder = 'positiveImage'; % folder of training images
imageName = dir([imageFolder '/*.jpg']);
numImage = size(imageName, 1); % number of training images
if numImage == 0
    % Fail early with a clear message instead of the index error that
    % size(I{1}) would raise below.
    error('No training images (*.jpg) found in folder ''%s''.', imageFolder);
end
I = cell(1, numImage);
for img = 1 : numImage
    tmpIm = imread([imageFolder '/' imageName(img).name]);
    if size(tmpIm, 3) == 3
        tmpIm = rgb2gray(tmpIm);
    end
    I{img} = imresize(single(tmpIm), resizeFactor, 'nearest');
end
[sizeTemplatex, sizeTemplatey] = size(I{1}); % size of template
%% Generate multi-resolution images and initialize SUM1, MAX1 and SUM2 maps
% For each training image: build numResolution rescaled copies, filter them
% with allFilter (ApplyFilterfft -> SUM1mapFind), apply Csigmoid, compute
% MAX1 maps with CgetMAX1, and cache the result on disk as
% [outputFolder 'working/SUMMAXmap' 'image' <img>] so that the learning and
% detection stages can reload one image at a time.
% NOTE(review): Csigmoid / CgetMAX1 are called without output arguments and
% their inputs (SUM1mapFind, MAX1map) are saved afterwards, so they appear to
% write into the passed arrays in place. Each buffer below must therefore be
% its own freshly allocated array (do not build them with deal/repmat of a
% shared array, or the in-place writes would alias).
numResolution = 5;  % number of resolutions to search for in detection stage
originalResolution = 3; % original resolution is the one at which the imresize factor = 1, i.e. .8+(originalResolution-1)*.1 = 1 in the resizeFactor formula below
allSizex = zeros(1, numResolution); allSizey = zeros(1, numResolution); 
ImageMultiResolution = cell(1, numResolution); 
% MAX2score / allFx / allFy are not used in this section; they are buffers
% handed to CgetSUM2 in the detection loop, which reads them back afterwards.
MAX2score = single(zeros(1, numResolution));  % maximum log-likelihood score at each resolution
allFx = zeros(1, numResolution); allFy = zeros(1, numResolution); % detected location at each resolution   
MAX1map = cell(numResolution, numOrient);
SUM2map = cell(1, numResolution); 
% NOTE(review): only cells 1..numResolution are filled in the loop below; the
% remaining cells stay empty (and are saved with every image's maps).
translatedTemplate = cell(1, numResolution*numImage); 
for (img = 1:numImage)
  % Fresh zero buffers for this image at every resolution.
  for(resolution=1:numResolution)
    resizeFactor = .8+(resolution-1)*.1; % so that .8+(originalResolution-1)*.1 = 1
    ImageMultiResolution{resolution} = imresize(I{img}, resizeFactor, 'nearest');  % images at multiple resolutions
    [sizex, sizey] = size(ImageMultiResolution{resolution}); 
    allSizex(resolution) = sizex; allSizey(resolution) = sizey; 
    for (orient = 1:numOrient) 
        MAX1map{resolution, orient} = single(zeros(sizex, sizey));
    end
    SUM2map{resolution} = single(zeros(sizex, sizey)); 
    translatedTemplate{resolution} = single(zeros(sizex, sizey));  
  end
  disp(['======> start filtering and maxing image ' num2str(img)]); tic
  SUM1mapFind = ApplyFilterfft(ImageMultiResolution, allFilter, localHalfx, localHalfy, thresholdFactor); % filtering images at multiple resolutions
  Csigmoid(numResolution, allSizex, allSizey, numOrient, saturation, SUM1mapFind);
  CgetMAX1(numResolution, allSizex, allSizey, numOrient,  ...
                     locationShiftLimit, orientShiftLimit, ...
                     SUM1mapFind, MAX1map);
  % Cache this image's maps (and the zero SUM2 / template buffers) on disk;
  % later sections restore exactly these variables with load(mapName).
  mapName = [outputFolder 'working/SUMMAXmap' 'image' num2str(img)];   
  save(mapName, 'ImageMultiResolution', 'SUM1mapFind', 'MAX1map', 'SUM2map', 'translatedTemplate', 'allSizex', 'allSizey');                
  disp(['filtering and maxing time: ' num2str(toc) ' seconds']);
end
%% Prepare output variables for learning
% Buffers that receive the learned model. CsharedSketch is called without
% output arguments and these variables are read afterwards (e.g.
% commonTemplate, selectedlambda), so it appears to write into them in place.
% NOTE(review): keep each buffer a separately allocated array (no deal/repmat
% of one shared array), otherwise the in-place writes would alias.
selectedOrient = zeros(1, numElement);  % orientation and location of selected Gabors
selectedx = zeros(1, numElement); 
selectedy = zeros(1, numElement); 
selectedlambda = zeros(1, numElement); % weighting parameter for scoring template matching
selectedLogZ = zeros(1, numElement); % normalizing constant
commonTemplate = single(zeros(sizeTemplatex, sizeTemplatey)); % template of active basis 
deformedTemplate = cell(1, numImage); % templates for training images 
for (img = 1 : numImage)
    deformedTemplate{img} = single(zeros(sizeTemplatex, sizeTemplatey));  
end
% One row per template variant used in detection (CgetSUM2 / CdrawTemplate
% read row r). The detection loop uses
%   r = rot + rotateShiftLimit + 1 + numRotate*flip,  flip = 0..flipOrNot,
% so rows 1..(flipOrNot+1)*numRotate are addressed. Size the buffers for
% that full range instead of only the numRotate unflipped rotations; the
% result is identical when flipOrNot = 0.
allSelectedx = zeros((flipOrNot + 1)*numRotate, numElement); 
allSelectedy = zeros((flipOrNot + 1)*numRotate, numElement); 
allSelectedOrient = zeros((flipOrNot + 1)*numRotate, numElement); 
%% Initialize learning 
% Fill SUM1mapLearn{img, orient} with every image's SUM1 map at
% originalResolution (resize factor 1, so the map has the template size),
% then learn an initial template from all images with CsharedSketch as if
% they were already aligned.
SUM1mapLearn = cell(numImage, numOrient); 
for (img = 1:numImage) 
    for (orient = 1:numOrient)
       SUM1mapLearn{img, orient} = single(zeros(sizeTemplatex, sizeTemplatey)); 
    end
end
SUM1mapLearnTmp = cell(1, numOrient); % scratch copy used by the flip step of the iteration loop
% NOTE(review): the next six variables are not referenced anywhere else in
% this file; they may be leftovers, or read by a script such as
% RotateTemplate -- confirm before removing.
numImage0 = 1; locationShiftLimit0 = 0; orientShiftLimit0 = 0; 
allSizex0 = sizeTemplatex + zeros(1, numImage); allSizey0 = sizeTemplatey + zeros(1, numImage); 
SUM1mapLearn0 = cell(1, numOrient); 
% NOTE(review): useWholeImageOrNot is not consulted here; the whole map at
% originalResolution is always copied.
ind = originalResolution; 
for (img = 1:numImage) 
% load restores this image's ImageMultiResolution, SUM1mapFind, MAX1map,
% SUM2map, translatedTemplate, allSizex and allSizey into the workspace.
mapName = [outputFolder 'working/SUMMAXmap' 'image' num2str(img)]; 
load(mapName); 
for (orient = 1 : numOrient)      
       % Zero offsets, template-sized source and destination, angle 0:
       % copies the map into the preallocated buffer.
       Ccopy(SUM1mapLearn{img, orient}, SUM1mapFind{ind, orient}, 0, 0, 0, 0, sizeTemplatex, sizeTemplatey, sizeTemplatex, sizeTemplatey, 0); 
end
end
disp(['**************** start learning initial template']); tic
% Learn the shared sketch; results are written into selected*, commonTemplate
% and deformedTemplate (called without outputs, read afterwards).
CsharedSketch(numOrient, locationShiftLimit, orientShiftLimit, ... % about active basis  
           numElement, numImage, sizeTemplatex, sizeTemplatey, SUM1mapLearn, ... % about training images 
           halfFilterSize, Correlation, allSymbol(1, :), ... % about filters
           numStoredPoint, storedlambda, storedExpectation, storedLogZ, ... % about exponential model 
           selectedOrient, selectedx, selectedy, selectedlambda, selectedLogZ, ... % learned parameters
           commonTemplate, deformedTemplate); % learned templates 
    disp(['mex-C learning time: ' num2str(toc) ' seconds']);
    Outepsgif(-double(commonTemplate), ['nonalign_sym' num2str(0)]); % output the initial template as 'nonalign_sym0'
    % RotateTemplate is a script (not visible here); presumably it derives
    % allSelectedx / allSelectedy / allSelectedOrient and centerx / centery
    % used by the detection loop -- confirm.
    RotateTemplate; 
%% Iteration part 1: detect the object in each image 
% Each of the numIteration passes (1) detects the current template in every
% image over flips, rotations and resolutions, and copies the aligned
% SUM1 patch into SUM1mapLearn; then (2) re-learns the template from the
% aligned patches.
for(it = 1:numIteration)
    disp(['**************** detection for iteration ' num2str(it)]);
    for (img=1:numImage)
        disp(['======> start detecting in image ' num2str(img)]); tic
        % Restore this image's cached maps (see the multi-resolution section).
        mapName = [outputFolder 'working/SUMMAXmap' 'image' num2str(img)]; 
        load(mapName); 
        % Best (flip, rotation, resolution, location) over all variants.
        % NOTE(review): if no variant scores above -1e10, Mrot/Mflip/Mind/MFx/MFy
        % keep the previous image's values (or are undefined for the first).
        MMAX2 = -1e10; 
        for (flip = 0 : flipOrNot)
          for (rot = -rotateShiftLimit : rotateShiftLimit)
            % Row of allSelected* for this variant: rotations first, then flipped copies.
            r = rot+rotateShiftLimit+1 + (rotateShiftLimit*2+1)*flip;  
            % CgetSUM2 is called without outputs; MAX2score/allFx/allFy are read back below.
            CgetSUM2(numResolution, allSizex, allSizey, numOrient, ...
                     numElement, allSelectedOrient(r, :), allSelectedx(r, :), allSelectedy(r, :), selectedlambda, selectedLogZ, ...
                     MAX1map, SUM2map, MAX2score, allFx, allFy, rangePercent);    
            [maxOverResolution ind] = max(MAX2score);   % most likely resolution
            if (MMAX2 < maxOverResolution) 
               MMAX2 = maxOverResolution; Mrot = rot; Mflip = flip; 
               Mind = ind; MFx = allFx(ind); MFy = allFy(ind); 
            end
          end
        end
        % On the last pass only, draw the detected template and keep it for the figures.
        if (it == numIteration)
          r = Mrot+rotateShiftLimit+1 + (rotateShiftLimit*2+1)*Mflip;  
          CdrawTemplate(numResolution, allSizex, allSizey, numOrient, ...
              locationShiftLimit, orientShiftLimit, ...
              halfFilterSize, SUM1mapFind, allSymbol(1, :), ...
              numElement, allSelectedOrient(r, :), allSelectedx(r, :), allSelectedy(r, :), ...
              translatedTemplate, Mind, MFx, MFy);
          % The image term is multiplied by 0, so only the negated, scaled sketch remains.
          Jshow{img} = ImageMultiResolution{Mind}*0-translatedTemplate{Mind}*100.;  % overlaid sketch 
        end
        disp(['mex-C finding time: ' num2str(toc) ' seconds']);
        % Copy the template-sized patch around (MFx, MFy), rotated by
        % Mrot*pi/numOrient, into SUM1mapLearn. Orientation channels are shifted
        % by -Mrot and wrapped into 1..numOrient (a single wrap suffices while
        % rotateShiftLimit < numOrient).
        % NOTE(review): centerx / centery are not defined in this file;
        % presumably set by RotateTemplate -- confirm.
        for (orient = 1 : numOrient)
           orient1 = orient - Mrot; 
           if (orient1 > numOrient)
               orient1 = orient1 - numOrient; 
           end
           if (orient1 <= 0)
               orient1 = orient1 + numOrient; 
           end                    
           Ccopy(SUM1mapLearn{img, orient}, SUM1mapFind{Mind, orient1}, MFx, MFy, centerx, centery, sizeTemplatex, sizeTemplatey, allSizex(Mind), allSizey(Mind), Mrot*pi/numOrient);  
        end
        % If the flipped template won, mirror the patch: orientation 1 maps to
        % itself, orientation k > 1 maps to numOrient - k + 2.
        if (Mflip > 0)
             % "+0." yields a separate array rather than a shared reference,
             % presumably because Cflip writes into SUM1mapLearn in place.
             for (orient = 1 : numOrient)
                 SUM1mapLearnTmp{1, orient} = SUM1mapLearn{img, orient}+0.; 
             end
             for (orient = 1 : numOrient)
                 if (orient-1>0)
                     orient2 = numOrient - (orient-1) + 1; 
                 else
                     orient2 = orient;   
                 end
                 Cflip(SUM1mapLearn{img, orient2}, SUM1mapLearnTmp{1, orient}, sizeTemplatex, sizeTemplatey); 
             end
        end
    end
%% Iteration part 2: re-learn the template
    % Same learning call as the initialization, now on the aligned patches.
    disp(['**************** multiple image learning for iteration ' num2str(it)]); tic    
    CsharedSketch(numOrient, locationShiftLimit, orientShiftLimit, ... % about active basis  
           numElement, numImage, sizeTemplatex, sizeTemplatey, SUM1mapLearn, ... % about training images 
           halfFilterSize, Correlation, allSymbol(1, :), ... % about filters
           numStoredPoint, storedlambda, storedExpectation, storedLogZ, ... % about exponential model 
           selectedOrient, selectedx, selectedy, selectedlambda, selectedLogZ, ... % learned parameters
           commonTemplate, deformedTemplate); % learned templates 
    disp(['mex-C learning time: ' num2str(toc) ' seconds']);
    Outepsgif(-double(commonTemplate), ['nonalign_sym' num2str(it)]); % output this iteration's template as 'nonalign_sym<it>'
    RotateTemplate; % refresh the template variants for the next detection pass (script, not visible here)
end
%% save figures into eps and png folders  
% showImage is a script that presumably displays the workspace variable pic in
% the current figure; every saveas call captures that figure (gcf).

% 1) The learned common template (negated for display).
pic = -double(commonTemplate);
showImage;
saveas(gcf, sprintf('%seps/template.eps', outputFolder), 'eps');
saveas(gcf, sprintf('%spng/template.png', outputFolder), 'png');
saveas(gcf, sprintf('%sActivate/sketch000.png', outputFolder), 'png');

% 2) Per training image: the original input, the sketch from the last
%    detection pass (Jshow), and its deformed template. File indices are
%    offset by 100 (101, 102, ...).
for img = 1:numImage
    pic = imread([imageFolder '/' imageName(img).name]);
    showImage;
    saveas(gcf, sprintf('%seps/input%d.eps', outputFolder, 100 + img), 'epsc');
    saveas(gcf, sprintf('%spng/input%d.png', outputFolder, 100 + img), 'png');

    pic = double(Jshow{img});
    showImage;
    saveas(gcf, sprintf('%seps/sketch%d.eps', outputFolder, 100 + img), 'eps');
    saveas(gcf, sprintf('%spng/sketch%d.png', outputFolder, 100 + img), 'png');

    pic = -double(deformedTemplate{img});
    showImage;
    saveas(gcf, sprintf('%sActivate/sketch%d.png', outputFolder, 100 + img), 'png');

    close all;
end
%% write the html code for reproducibility page
% Writes [outputFolder '.html']. The page contains:
%   - a link to the code archive,
%   - a summary of the learning parameters,
%   - the learned template,
%   - one (input, sketch) image pair per training image, with a <br> after
%     every third pair.
% heightstr closes the SRC="..." value of an <IMG> tag and adds the height
% attribute. The leading space is required; without it, the attribute runs
% into the quoted value and the tag is malformed.
heightstr = '" height=80> ';
htmlFile = [outputFolder '.html'];
[fid, openMessage] = fopen(htmlFile, 'wt');
if fid < 0
    error('Cannot open %s for writing: %s', htmlFile, openMessage);
end
fprintf(fid, '%s\n', ['<a href="http://www.stat.ucla.edu/~ywu/AB/' outputFolder 'Code.zip">Code and data</a> <br>']);
outMessage =  ['==> Number of training images = ' num2str(numImage) ...
                          '; Number of elements = ' num2str(numElement) ...
                          '; Length of Gabor = ' num2str(halfFilterSize*2+1) ' pixels' ...
                          '; Range of displacement = ' num2str(locationShiftLimit) ' pixels' ...
                          '; Range of rotation of template = '  num2str(rotateShiftLimit) ' times '  ' pi/' num2str(numOrient)...
                          '<br>'];
fprintf(fid, '%s\n', outMessage);
fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/template' '.png' heightstr]);
fprintf(fid, '%s\n', '<br>');
for img = 1 : numImage
    fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/input' num2str(100+img) '.png' heightstr]);
    fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/sketch' num2str(100+img) '.png' heightstr]);
    if mod(img, 3) == 0   % line break after every third (input, sketch) pair
        fprintf(fid, '%s\n', '<br>');
    end
end
fclose(fid);   % flush and release the file handle
% Remove the per-image SUM/MAX map files cached during filtering; the
% folder itself is kept.
if exist([outputFolder 'working']) ~= 0
    delete([outputFolder 'working/*.*']);
end
  
  