% LEARNING ACTIVE BASIS MODEL FROM NONALIGNED IMAGES
%% Load in exponential model, mex C codes, and set parameters
% Run the ExponentialModel script beforehand if 'storedExponentialModel' does not exist yet.
clear;
close('all');
% After "clear", variables such as allFilter (used below) can only come from
% this file; presumably it also provides numOrient, halfFilterSize, allSymbol,
% saturation and the stored* tables -- confirm against ExponentialModel.
load('storedExponentialModel');
% Compile the C routines used by this script.
mex('CsharedSketch.c'); % learning from aligned images
mex('CgetSUM2.c');      % detect object from multi-resolution images
mex('Csigmoid.c');      % sigmoid transformation
mex('CgetMAX1.c');      % compute MAX1 maps
mex('Ccopy.c');         % copy around a detected location
% Learning parameters.
epsilon = .1;            % allowed correlation between selected Gabors
subsample = 1;           % subsampling step when computing MAX1 and SUM2 maps
locationShiftLimit = 3;  % shift in normal direction = locationShiftLimit*subsample pixels
orientShiftLimit = 1;    % allowed shift in orientation
Correlation = CorrFilter(allFilter, epsilon); % correlation between filters
numElement = 27;         % number of Gabors in the active basis
numIteration = 15;       % number of learning iterations
useWholeImageOrNot = 1;  % >0: initialize learning from the whole I{starting}; otherwise from a patch of it
rotateShiftLimit = 2;    % template rotations searched: -rotateShiftLimit .. rotateShiftLimit
numRotate = 2*rotateShiftLimit + 1;
% Prepare the output locations: the folders <outputFolder>eps,
% <outputFolder>png and <outputFolder>Activate are emptied if they exist
% and created otherwise; a previous <outputFolder>.html report is deleted.
outputFolder = 'dogRR';
outputSubfolderSuffixes = {'eps', 'png', 'Activate'};
for (subfolderIndex = 1 : numel(outputSubfolderSuffixes))
    outputSubfolder = [outputFolder outputSubfolderSuffixes{subfolderIndex}];
    % Ask for a folder in the current directory explicitly: a bare exist()
    % also matches variables, functions and files on the MATLAB path.
    if (exist(fullfile(pwd, outputSubfolder), 'dir') == 7)
        delete([outputSubfolder '/*.*']);
    else
        mkdir(outputSubfolder);
    end
end
clear outputSubfolderSuffixes subfolderIndex outputSubfolder;
% Only delete the report if it is an actual file in the current directory.
if (exist(fullfile(pwd, [outputFolder '.html']), 'file') == 2)
    delete([outputFolder '.html']);
end
%% Load in training images and initialize SUM1 maps for learning
resizeFactor = .35; % resize the input images 
imageFolder = 'positiveImage'; % folder of training images  
imageName = dir([imageFolder '/*.jpg']);
numImage = size(imageName, 1); % number of training images 
if (numImage == 0)
    % fail early with a clear message instead of an index error at I{starting}
    error(['no training images (*.jpg) found in folder ' imageFolder]);
end
I = cell(1, numImage); 
for (img = 1 : numImage)
    tmpIm = imread([imageFolder '/' imageName(img).name]); 
    if size(tmpIm,3) == 3
        tmpIm = rgb2gray(tmpIm); % learning works on gray-level images
    end
    I{img} = imresize(single(tmpIm), resizeFactor, 'nearest');
end
starting = 1; % initialize by single image learning from I{starting}
startx = 6; endx = 132; % initialize by learning from a given patch
starty = 35; endy = 175; 
tmp = I{starting}; 
if (useWholeImageOrNot>0)  % initialize from the whole image 
    startx = 1; starty = 1; 
    [endx endy] = size(tmp); 
elseif (endx > size(tmp, 1) || endy > size(tmp, 2))
    % the hand-picked patch is given in resized-image coordinates
    error('initial patch rows %d:%d, columns %d:%d exceed the %d x %d resized starting image', ...
          startx, endx, starty, endy, size(tmp, 1), size(tmp, 2));
end
I0{1} = tmp(startx:endx, starty:endy); % initialize the template from the starting image 
[sizeTemplatex sizeTemplatey] = size(I0{1}); % size of template
% One separately allocated map per (image, orientation): these are filled
% in place by Ccopy during the iterations, so they must not share storage.
SUM1mapLearn = cell(numImage, numOrient); 
for (img = 1:numImage) 
    for (orient = 1:numOrient)
       SUM1mapLearn{img, orient} = single(zeros(sizeTemplatex, sizeTemplatey)); 
    end
end
%% Generate multi-resolution images and initialize SUM1, MAX1 and SUM2 maps
numResolution = 5;  % number of resolutions to search for in detection stage
originalResolution = 3; % index j of the resolution whose imresize factor is 1 (the initial template is cropped from it)
minResolution = .6; resolutionStep = .2; % resolution j uses imresize factor minResolution+(j-1)*resolutionStep
if (abs(minResolution + (originalResolution-1)*resolutionStep - 1) > 1e-6)
    error('originalResolution must select the resolution whose imresize factor is 1');
end
allSizeMatrixx = zeros(numImage, numResolution); allSizeVectorx = zeros(1, numResolution*numImage); 
allSizeMatrixy = zeros(numImage, numResolution); allSizeVectory = zeros(1, numResolution*numImage); 
ImageMultiResolution = cell(1, numResolution*numImage); 
MAX1map = cell(numResolution*numImage, numOrient);
SUM2map = cell(1, numResolution*numImage); 
translatedTemplate = cell(1, numResolution*numImage); 
% Entry ind = (img-1)*numResolution+j holds image img at resolution j.
% MAX1map, SUM2map and translatedTemplate are written in place by the mex
% routines, so every entry gets its own zeros() allocation.
for (img = 1:numImage)
  for(j=1:numResolution)
    resolution = minResolution+(j-1)*resolutionStep; % equals 1 when j == originalResolution
    ind = (img-1)*numResolution+j; 
    ImageMultiResolution{ind} = imresize(I{img}, resolution, 'nearest');  % images at multiple resolutions
    [sizex, sizey] = size(ImageMultiResolution{ind}); 
    allSizeMatrixx(img, j) = sizex; allSizeVectorx(ind) = sizex; 
    allSizeMatrixy(img, j) = sizey; allSizeVectory(ind) = sizey; 
    for (orient = 1:numOrient) 
        MAX1map{ind, orient} = single(zeros(floor(sizex/subsample), floor(sizey/subsample)));
    end
    SUM2map{ind} = single(zeros(floor((sizex)/subsample), floor((sizey)/subsample)));  
    translatedTemplate{ind} = single(zeros(sizex, sizey));  
  end
end
disp(['start filtering training images at all resolutions']); tic
SUM1mapFind = ApplyFilterfft(ImageMultiResolution, allFilter, localHalfx, localHalfy, thresholdFactor); % filtering images at multiple resolutions
disp(['filtering time: ' num2str(toc) ' seconds']);
% Csigmoid transforms SUM1mapFind in place; CgetMAX1 fills MAX1map in place.
Csigmoid(numResolution*numImage, allSizeVectorx, allSizeVectory, numOrient, saturation, SUM1mapFind);
CgetMAX1(numResolution*numImage, allSizeVectorx, allSizeVectory, numOrient,  ...
                     locationShiftLimit, orientShiftLimit, subsample, ...
                     SUM1mapFind, MAX1map);
%% Prepare output variables for learning
% NOTE: CsharedSketch and CgetSUM2 return their results by writing into these
% preallocated arrays in place -- the script reads commonTemplate, selectedx,
% etc. after the mex calls without assigning them. Keep one zeros() call per
% array; initializing one array from another (b = a) may let MATLAB share the
% underlying data (copy-on-write), and an in-place write would then hit both.
selectedOrient = zeros(1, numElement);  % orientation and location of selected Gabors
selectedx = zeros(1, numElement); % location (x) of selected Gabors
selectedy = zeros(1, numElement); % location (y) of selected Gabors
selectedlambda = zeros(1, numElement); % weighting parameter for scoring template matching
selectedLogZ = zeros(1, numElement); % normalizing constant
commonTemplate = single(zeros(sizeTemplatex, sizeTemplatey)); % template of active basis 
deformedTemplate = cell(1, numImage); % templates for training images 
for (img = 1 : numImage)
    deformedTemplate{img} = single(zeros(sizeTemplatex, sizeTemplatey));  
end
% One row per template rotation (row r corresponds to rotation r-rotateShiftLimit-1);
% presumably filled by the RotateTemplate script, which is not visible here -- confirm.
allSelectedx = zeros(numRotate, numElement); 
allSelectedy = zeros(numRotate, numElement); 
allSelectedOrient = zeros(numRotate, numElement); 
%% Initialize learning from the first image 
% Learn an initial template from the (cropped) starting image alone, with no
% location or orientation perturbation allowed.
numImage0 = 1; locationShiftLimit0 = 0; orientShiftLimit0 = 0; 
SUM1mapLearn0 = cell(1, numOrient); 
% Crop from the resolution whose imresize factor is 1, i.e. the coordinates of I{starting}.
ind = (starting-1)*numResolution+originalResolution; 
sizex = allSizeVectorx(ind); sizey = allSizeVectory(ind); 
for (orient = 1 : numOrient)      
       SUM1mapLearn0{1, orient} = single(zeros(sizeTemplatex, sizeTemplatey)); 
       Ccopy(SUM1mapLearn0{1, orient}, SUM1mapFind{ind, orient}, startx-1, starty-1, 0, 0, sizeTemplatex, sizeTemplatey, sizex, sizey, 0); 
end
deformedTemplate0{1} = single(zeros(sizeTemplatex, sizeTemplatey));
allSizex0 = zeros(1, 2) + sizeTemplatex; allSizey0 = zeros(1, 2) + sizeTemplatey; 
% Keep a deep copy of the starting image's SUM1 maps ("+0." forces a new
% array instead of a shared copy-on-write reference). The main loop restores
% SUM1mapLearn{starting, :} from it, presumably because CsharedSketch modifies
% the maps it is given in place -- confirm in CsharedSketch.c.
for (orient = 1:numOrient)
    SUM1map0Reserved{1, orient} = SUM1mapLearn0{1, orient}+0.; 
end
disp(['start from single image learning']); tic
CsharedSketch(numOrient, locationShiftLimit0, orientShiftLimit0, subsample, ... % about active basis  
       numElement, numImage0, sizeTemplatex, sizeTemplatey, SUM1mapLearn0, ... % about training images 
       halfFilterSize, Correlation, allSymbol(1, :), ... % about filters
       numStoredPoint, storedlambda, storedExpectation, storedLogZ, ... % about exponential model 
       selectedOrient, selectedx, selectedy, selectedlambda, selectedLogZ, ... % learned parameters
       commonTemplate, deformedTemplate0); % learned templates 
disp(['mex-C learning time: ' num2str(toc) ' seconds']);
Outepsgif(-double(commonTemplate), 'nonalign_sym0'); 
RotateTemplate; 
%% Iteration part 1: detect the object in each image 
% Each of the numIteration steps does two things:
%   (1) it detects the current template in every training image except
%       I{starting}, over all rotations and resolutions, and copies the SUM1
%       maps around the best detection into SUM1mapLearn;
%   (2) it re-learns the shared template from all SUM1mapLearn maps.
for(it = 1:numIteration)
    disp(['detection for step ' num2str(it)]);
    for (img=1:numImage)
      if (img~=starting)
        disp(['    start detecting in image ' num2str(img)]);
        % Pick the rotated template (row r of allSelected*) with the highest
        % MAX2 score over all resolutions. Mrot is reset here so that a
        % rotation found for a previous image can never be reused silently.
        MMAX2 = -Inf; Mrot = 0; 
        for (rot = -rotateShiftLimit : rotateShiftLimit)
          tic; % time each detection on its own
          r = rot+rotateShiftLimit+1; 
          MAX2score = single(zeros(1, numResolution)); Fx = zeros(1, numResolution); Fy = zeros(1, numResolution); 
          allSizex = allSizeMatrixx(img, :); allSizey = allSizeMatrixy(img, :); 
          % CgetSUM2 fills MAX2score, Fx, Fy and translatedTemplate in place
          CgetSUM2(img, numImage, numResolution, allSizex, allSizey, numOrient, ...
               locationShiftLimit, orientShiftLimit, subsample, ...
               sizeTemplatex, sizeTemplatey, ...
               halfFilterSize, SUM1mapFind, allSymbol(1, :), ...
               numElement, allSelectedOrient(r, :), allSelectedx(r, :), allSelectedy(r, :), selectedlambda, selectedLogZ, ...
               MAX1map, SUM2map, translatedTemplate, MAX2score, Fx, Fy);
          disp(['mex-C finding time: ' num2str(toc) ' seconds']);
          [maxOverResolution j] = max(MAX2score);   % most likely resolution
          if (MMAX2 < maxOverResolution) 
             MMAX2 = maxOverResolution; Mrot = rot; 
          end
        end
        % Re-run detection with the best rotation: the buffers written by
        % CgetSUM2 currently hold the results of the last rotation tried.
        r = Mrot+rotateShiftLimit+1; tic; 
        MAX2score = single(zeros(1, numResolution)); Fx = zeros(1, numResolution); Fy = zeros(1, numResolution); 
        allSizex = allSizeMatrixx(img, :); allSizey = allSizeMatrixy(img, :); 
        CgetSUM2(img, numImage, numResolution, allSizex, allSizey, numOrient, ...
              locationShiftLimit, orientShiftLimit, subsample, ...
              sizeTemplatex, sizeTemplatey, ...
              halfFilterSize, SUM1mapFind, allSymbol(1, :), ...
              numElement, allSelectedOrient(r, :), allSelectedx(r, :), allSelectedy(r, :), selectedlambda, selectedLogZ, ...
              MAX1map, SUM2map, translatedTemplate, MAX2score, Fx, Fy);
        disp(['mex-C finding time: ' num2str(toc) ' seconds']);
        [maxOverResolution j] = max(MAX2score);   % most likely resolution
        
        ind = (img-1)*numResolution+j; 
        Jshow{img} = ImageMultiResolution{ind}+translatedTemplate{ind}*100.;  % overlaid sketch 
        % Copy the SUM1 maps around the detection into this image's learning
        % maps, passing the rotation angle Mrot*pi/numOrient to Ccopy.
        % NOTE(review): centerx/centery are not defined in this file;
        % presumably set by RotateTemplate -- confirm.
        for (orient = 1 : numOrient)
           orient1 = mod(orient - Mrot - 1, numOrient) + 1; % orient shifted by -Mrot, wrapped into 1..numOrient
           Ccopy(SUM1mapLearn{img, orient}, SUM1mapFind{ind, orient1}, Fx(j), Fy(j), centerx, centery, sizeTemplatex, sizeTemplatey, allSizex(j), allSizey(j), Mrot*pi/numOrient); 
        end
      end
    end
    % The starting image is not re-detected: restore its maps from the copy
    % reserved before the first learning step.
    for (orient = 1 : numOrient)
       Ccopy(SUM1mapLearn{starting, orient}, SUM1map0Reserved{1, orient}, 0, 0, 0, 0, sizeTemplatex, sizeTemplatey, sizeTemplatex, sizeTemplatey, 0); 
    end
%% Iteration part 2: re-learn the template
    disp(['multiple image learning for step ' num2str(it)]);
    tic
    CsharedSketch(numOrient, locationShiftLimit, orientShiftLimit, subsample, ... % about active basis  
           numElement, numImage, sizeTemplatex, sizeTemplatey, SUM1mapLearn, ... % about training images 
           halfFilterSize, Correlation, allSymbol(1, :), ... % about filters
           numStoredPoint, storedlambda, storedExpectation, storedLogZ, ... % about exponential model 
           selectedOrient, selectedx, selectedy, selectedlambda, selectedLogZ, ... % learned parameters
           commonTemplate, deformedTemplate); % learned templates 
    disp(['mex-C learning time: ' num2str(toc) ' seconds']);
    Outepsgif(-double(commonTemplate), ['nonalign_sym' num2str(it)]); 
    RotateTemplate; 
end
%% display and save results
% A disabled "if (0)" export via Outepsgif used to sit here. It was removed
% because it added deformedTemplate{starting} onto tmp. The figure-saving
% section below adds it again, so enabling the block would have doubled the
% overlay. The section below also writes the template and overlaid sketches.
save 'learnedTemplate' numElement selectedOrient selectedx selectedy selectedlambda selectedLogZ ...
      numOrient halfFilterSize allFilter allSymbol locationShiftLimit orientShiftLimit saturation subsample ...
      resizeFactor numResolution;
%% save figures into eps and png folders  
% Overlay the deformed template on the starting image; the other images got
% their overlaid sketch during detection.
tmp(startx:endx, starty:endy) = tmp(startx:endx, starty:endy) + deformedTemplate{starting}*100.; 
Jshow{starting} = tmp; 
% showImage is a script; it presumably displays the array in pic in the current figure.
pic = -double(commonTemplate);
showImage;
saveas(gcf, sprintf('%seps/template.eps', outputFolder), 'eps');
saveas(gcf, sprintf('%spng/template.png', outputFolder), 'png');
saveas(gcf, sprintf('%sActivate/sketch000.png', outputFolder), 'png');
for (img = 1 : numImage)
    % original input image (may be color, hence the color EPS format 'epsc')
    pic = imread([imageFolder '/' imageName(img).name]);
    showImage;
    saveas(gcf, sprintf('%seps/input%d.eps', outputFolder, 100+img), 'epsc');
    saveas(gcf, sprintf('%spng/input%d.png', outputFolder, 100+img), 'png');
    % image with the detected template overlaid
    pic = double(Jshow{img});
    showImage;
    saveas(gcf, sprintf('%seps/sketch%d.eps', outputFolder, 100+img), 'eps');
    saveas(gcf, sprintf('%spng/sketch%d.png', outputFolder, 100+img), 'png');
    % deformed template of this image alone
    pic = -double(deformedTemplate{img});
    showImage;
    saveas(gcf, sprintf('%sActivate/sketch%d.png', outputFolder, 100+img), 'png');
    close('all');
end
%% write the html code for reproducibility page
% Writes <outputFolder>.html listing the parameters, the template, and, for
% each training image, its input and sketch pictures (three pairs per row).
heightstr = '" height=120> '; % closes the SRC value; the space separates the height attribute
fid = fopen([outputFolder '.html'], 'wt'); 
if (fid < 0)
    error(['cannot open ' outputFolder '.html for writing']);
end
% Displacement in pixels is locationShiftLimit*subsample (see the parameter section).
outMessage =  ['==> Number of training images = ' num2str(numImage) ...
                          '; Number of elements = ' num2str(numElement) ...
                          '; Length of Gabor = ' num2str(halfFilterSize*2+1) ' pixels' ...
                          '; Range of displacement = ' num2str(locationShiftLimit*subsample) ' pixels' ...
                          '; Subsample rate = ' num2str(subsample) ' pixel' ...
                          '; Range of rotation of template = '  num2str(rotateShiftLimit) ' times '  ' pi/' num2str(numOrient)...
                          '<br>'];
fprintf(fid, '%s\n', outMessage);
fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/template' '.png' heightstr]);
fprintf(fid, '%s\n', ['<br>']);
counter = 0; 
for (img = 1 : numImage)
 fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/input' num2str(100+img) '.png' heightstr]);
 fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/sketch' num2str(100+img) '.png' heightstr]);
 counter = counter + 1; 
 if (counter == 3) % start a new row after three input/sketch pairs
     fprintf(fid, '%s\n', ['<br>']);
     counter = 0; 
 end
end
fclose(fid); % flush the report and release the file handle

  
  