% LEARNING ACTIVE BASIS MODEL FROM NONALIGNED IMAGES
%% Load in exponential model, mex C codes, and set parameters
% ExponentialModel; % this line is not needed if 'storedExponentialModel' exists
clear;
close all;
% Exponential model (the saved output of ExponentialModel). It presumably supplies
% allFilter, numOrient and the other filter settings used below but not defined here.
load('storedExponentialModel');

% Compile the MEX kernels. Several are not called in this file and are presumably
% used by the helper scripts it runs (e.g. ABlearnRep) -- TODO confirm.
mex('CsharedSketch.c');   % learning from aligned images
mex('Csigmoid.c');        % sigmoid transformation
mex('CgetMAX1.c');        % getting MAX1 maps
mex('CgetSUM2.c');        % detect object from multi-resolution images
mex('CdrawTemplate.c');   % draw translated and deformed template
mex('Ccopy.c');           % copy around detect location
mex('Cflip.c');           % copy with flipping

% --- active basis selection ---
epsilon = .1;                                   % allowed correlation between selected Gabors
Correlation = CorrFilter(allFilter, epsilon);   % correlation between filters
numElementLimit = 30;                           % limit on number of Gabors in active basis
responseLimit = 0.;                             % lower limit on average response of selected elements (0 => always take numElementLimit)
numElementPointer = zeros(1, 2);                % output buffer receiving the number actually selected (MATLAB has no pointers)

% --- deformation allowed when matching ---
locationShiftLimit = 1;               % shift in normal direction = locationShiftLimit*subsample pixels
orientShiftLimit = 2;                 % shift in orientation
rotateShiftLimit = 1;                 % limit of rotation in learning
numRotate = 2*rotateShiftLimit + 1;   % number of template rotations tried (presumably -limit..+limit)
flipOrNot = 0;                        % left-right flip of the template, 0 or 1

% --- EM clustering ---
numIteration = 5;         % iterations per multiple-starting run
finalNumIteration = 10;   % iterations in the final run
numCluster = 50;          % number of clusters
outputFolder = 'RepclusterDigitRR';
% Output layout: <outputFolder>eps/, <outputFolder>png/, <outputFolder>working/ and the
% page <outputFolder>.html. Existing folders are emptied, missing ones are created, and
% any stale HTML page is deleted, so every run starts from a clean state.
% The typed exist(...,'dir') / exist(...,'file') checks stop a variable, function or
% plain file with the same name from being mistaken for the target.
for outputSubfolder = {'eps', 'png', 'working'}
    folderPath = [outputFolder outputSubfolder{1}];
    if exist(folderPath, 'dir')
        delete([folderPath '/*.*']);
    else
        mkdir(folderPath);
    end
end
clear outputSubfolder folderPath;   % keep temporaries out of the workspace shared with helper scripts
if exist([outputFolder '.html'], 'file')
    delete([outputFolder '.html']);
end
rangePercent = .1;   % maximum percentage of image size shifted from center
numRepetition = 10;  % number of multiple-starting runs
%% Load in training images and initialize SUM1 maps for learning
% Reads up to nn images of each digit 0..nD-1 from digit<k>.mat. Each file is assumed to
% hold a matrix D with one flattened 28x28 image per row (TODO confirm). Every image is
% resized to sizex-by-sizey, converted to single and transposed, and stored in I{...}.
disp(['get images']);
sizex = 70; sizey = 70;   % size every digit is resized to
nD = 10;                  % number of digit classes (files digit0.mat .. digit<nD-1>.mat)
nn = 500;                 % images taken per digit class
I = cell(1, nD*nn);       % preallocated; trimmed to the actual count below
tt = 0;
for k = 0:nD-1
    digitData = load(['digit' num2str(k) '.mat']);
    D = digitData.D;
    for j = 1:min(nn, size(D, 1))   % tolerate files with fewer than nn images
        tt = tt + 1;
        % Transpose after resizing: D rows presumably store pixels row-major, while
        % reshape fills column-major.
        I{tt} = single(imresize(reshape(D(j,:), 28, 28), [sizex sizey], 'nearest'))';
    end
end
clear digitData;
numImage = tt;
I = I(1:numImage);
disp(['number of training images = ' num2str(numImage)]);
sizeTemplatex = sizex; sizeTemplatey = sizey;   % template has the size of a training image
allSizex = zeros(1, numImage)+sizex;
allSizey = zeros(1, numImage)+sizey;
%% Generate multi-resolution images and initialize SUM1, MAX1 and SUM2 maps
% For every training image: build numResolution rescaled copies, filter them with the
% Gabor bank (ApplyFilterfft), apply Csigmoid, compute MAX1 maps (CgetMAX1), and save the
% results to <outputFolder>working/SUM1mapimage<k>.mat and MAX1mapimage<k>.mat.
% NOTE(review): Csigmoid and CgetMAX1 return nothing, yet SUM1mapFind / MAX1map are
% saved right after them. They therefore appear to write into their arguments in place.
% That is why fresh zero arrays are allocated per image; do not make these cells share
% storage (e.g. via repmat or copying one array into several cells).
numResolution = 5;  % number of resolutions to search for in detection stage
originalResolution = 3; % original resolution is the one at which the imresize factor = 1 (see the resizeFactor formula in the loop below)
allSizex = zeros(1, numResolution); allSizey = zeros(1, numResolution);
ImageMultiResolution = cell(1, numResolution);
MAX2score = single(zeros(1, numResolution));  % maximum log-likelihood score at each resolution
allFx = zeros(1, numResolution); allFy = zeros(1, numResolution); % detected location at each resolution
MAX1map = cell(numResolution, numOrient);
SUM2map = cell(1, numResolution);
% NOTE(review): only the first numResolution cells are filled below; the remaining
% cells stay empty but are still saved with every image.
translatedTemplate = cell(1, numResolution*numImage);
for (img = 1:numImage)
  for(resolution=1:numResolution)
    resizeFactor = .8+(resolution-1)*.1; % so that .8+(originalResolution-1)*.1 = 1; factors .8,.9,1,1.1,1.2 for numResolution=5
    ImageMultiResolution{resolution} = imresize(I{img}, resizeFactor, 'nearest');  % images at multiple resolutions
    % NOTE(review): this overwrites the script-level sizex/sizey (the training image
    % size); after the loop they hold the last image's largest-resolution size. The
    % template size is kept separately in sizeTemplatex/sizeTemplatey.
    [sizex, sizey] = size(ImageMultiResolution{resolution});
    allSizex(resolution) = sizex; allSizey(resolution) = sizey;
    for (orient = 1:numOrient)
        MAX1map{resolution, orient} = single(zeros(sizex, sizey));
    end
    SUM2map{resolution} = single(zeros(sizex, sizey));
    translatedTemplate{resolution} = single(zeros(sizex, sizey));
  end
  disp(['======> start filtering and maxing image ' num2str(img)]); tic
  SUM1mapFind = ApplyFilterfft(ImageMultiResolution, allFilter, localHalfx, localHalfy, thresholdFactor); % filtering images at multiple resolutions
  Csigmoid(numResolution, allSizex, allSizey, numOrient, saturation, SUM1mapFind);  % result read back from SUM1mapFind
  CgetMAX1(numResolution, allSizex, allSizey, numOrient,  ...
                     locationShiftLimit, orientShiftLimit, ...
                     SUM1mapFind, MAX1map);  % result read back from MAX1map
  % Cache per-image maps on disk.
  SUM1mapName = [outputFolder 'working/SUM1map' 'image' num2str(img)];
  save(SUM1mapName, 'ImageMultiResolution', 'SUM1mapFind',  'translatedTemplate', 'allSizex', 'allSizey');
  MAX1mapName = [outputFolder 'working/MAX1map' 'image' num2str(img)];
  save(MAX1mapName, 'MAX1map', 'SUM2map', 'allSizex', 'allSizey');
  disp(['filtering and maxing time: ' num2str(toc) ' seconds']);
end
% Per-image, per-cluster bookkeeping, zero-initialized here (presumably best rotation,
% flip, resolution index and location, filled in by ABlearnRep -- TODO confirm).
MrotAll = zeros(numImage, numCluster); MflipAll = zeros(numImage, numCluster);
MindAll = zeros(numImage, numCluster); MFxAll = zeros(numImage, numCluster); MFyAll = zeros(numImage, numCluster);
%% Prepare output variables for learning
% Zero-filled output buffers for the learning step. Every cell gets its own freshly
% allocated array, which matters if MEX code writes into them in place.

% Per-element description of the selected Gabors (one slot per possible element)
selectedOrient = zeros(1, numElementLimit);   % orientations of selected Gabors
selectedx      = zeros(1, numElementLimit);   % x locations of selected Gabors
selectedy      = zeros(1, numElementLimit);   % y locations of selected Gabors
selectedlambda = zeros(1, numElementLimit);   % weighting parameter for scoring template matching
selectedLogZ   = zeros(1, numElementLimit);   % normalizing constant

% Active basis template image
commonTemplate = zeros(sizeTemplatex, sizeTemplatey, 'single');

% Selected elements, one row per template rotation
allSelectedx      = zeros(numRotate, numElementLimit);
allSelectedy      = zeros(numRotate, numElementLimit);
allSelectedOrient = zeros(numRotate, numElementLimit);

% Per-image deformed templates and per-image/per-orientation SUM1 maps for learning
deformedTemplate = cell(1, numImage);
SUM1mapLearn = cell(numImage, numOrient);
for img = 1:numImage
    deformedTemplate{img} = zeros(sizeTemplatex, sizeTemplatey, 'single');
    for orient = 1:numOrient
        SUM1mapLearn{img, orient} = zeros(sizeTemplatex, sizeTemplatey, 'single');
    end
end
%% initialize EM from random cluster
% Multiple starting: run the learner (ABlearnRep, a script sharing this workspace)
% from numRepetition random initial image-by-cluster scores. Each run is scored by
% the sum over images of the image's best cluster score, and the highest-scoring
% initialization is kept. Every initialization is saved to
% <outputFolder>working/Initialrep<r>.mat so the winner can be reloaded for the final,
% longer run with finalNumIteration iterations.
if numRepetition < 1
    error('numRepetition must be at least 1 to choose an initialization.');
end
MSUMofMAX2 = -Inf;   % best total score so far; -Inf so any finite score wins
Mrep = 1;            % repetition with the best score (fallback if every score is NaN)
for rep = 1:numRepetition
    disp(['LEARNING REPETITION ' num2str(rep)]);
    MAX2scoreAll = rand(numImage, numCluster);   % random initial image-cluster scores
    InitialName = [outputFolder 'working/Initial' 'rep' num2str(rep)];
    save(InitialName, 'MAX2scoreAll');           % keep this starting point for the final rerun
    ABlearnRep;                                   % presumably overwrites MAX2scoreAll with learned scores
    SUMofMAX2 = sum(max(MAX2scoreAll, [], 2));   % each image counted under its best cluster
    if (MSUMofMAX2 < SUMofMAX2)
        MSUMofMAX2 = SUMofMAX2;
        Mrep = rep;
    end
end
% Final run from the best initialization, with more iterations.
InitialName = [outputFolder 'working/Initial' 'rep' num2str(Mrep)];
load(InitialName);   % restores MAX2scoreAll of the best start
numIteration = finalNumIteration;
ABlearnRep;
%% save figures into eps and png folders
% Writes <outputFolder>.html and its PNG figures:
%   - a header with the run settings,
%   - a row showing every cluster template,
%   - one section per cluster: its template, then each member image and its sketch
%     (Jshow), imagesPerRow image/sketch pairs per line, in descending score order.
% An image is a member of cluster c when its score for c equals its maximum score over
% all clusters.
disp(['Displaying the clusters']);
fid = fopen([outputFolder '.html'], 'wt');
if fid < 0
    error('Cannot open %s for writing.', [outputFolder '.html']);
end
fprintf(fid, '%s\n', ['<a href="http://www.stat.ucla.edu/~ywu/AB/' outputFolder 'Code.zip">Code and data</a> <br>']);
outMessage = ['==> Number of multiple starting = ' num2str(numRepetition) ...
              '; Number of training images = ' num2str(numImage) ...
              '; Image height and width = '  num2str(sizeTemplatex) ' by ' num2str(sizeTemplatey) ' pixels ' ...
              '; Number of clusters = ' num2str(numCluster) ...
              '; Number of elements = ' num2str(numElement) ...
              '; Range of displacement = ' num2str(locationShiftLimit) ' pixels' ...
              '; Number of iteration = ' num2str(numIteration) ...
              '; Length of Gabor = ' num2str(halfFilterSize*2+1) ' pixels' ...
              '; Range of rotation of template = '  num2str(rotateShiftLimit) ' times '  ' pi/' num2str(numOrient)...
              '; Flip of template = ' num2str(flipOrNot)...
              '<br>'];
fprintf(fid, '%s\n', outMessage);
disp(outMessage);
fprintf(fid, '%s\n', ['<hr>']);
% Closing parts of the <IMG> tags. The leading space is required to separate the
% height attribute from the preceding SRC="..." attribute.
heightstr0 = '" height="60"> '; heightstr = '" height="50"> ';
% Image/sketch pairs per HTML line. Not called "length": that would shadow the builtin
% length() for the showImage script, which runs in this same workspace.
imagesPerRow = 4;

% Overview row with all cluster templates.
for c = 1:numCluster
    fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/000template' num2str(100+c) '.png' heightstr0]);
end
fprintf(fid, '%s\n', ['<hr>']);

for c = 1:numCluster
    templateName = [outputFolder 'working/template' 'cluster' num2str(c)];
    load(templateName);   % presumably restores commonTemplate for cluster c
    % showImage (a script) presumably renders "pic" into the current figure.
    pic = -double(commonTemplate); showImage;
    saveas(gcf, [outputFolder 'png/000template' num2str(100+c) '.png'], 'png');
    pic = commonTemplate*0.+255; showImage;
    saveas(gcf, [outputFolder 'png/empty.png'], 'png');
    fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/000template' num2str(100+c) '.png' heightstr0]);
    fprintf(fid, '%s\n', ['<br>']);
    [MAX2scoresort, ind] = sort(MAX2scoreAll(:, c), 'descend');   % list members from highest score down
    tt = 0; counter = 0;
    for i = 1:numImage
        this = ind(i);
        if (MAX2scoreAll(this, c) == max(MAX2scoreAll(this, :)))   % image assigned to cluster c
            tt = tt + 1; close all;
            pic = I{this}; showImage;
            saveas(gcf, [outputFolder 'png/' num2str(100+c) 'I' num2str(1000+tt) '.png'], 'png');
            pic = Jshow{this}; showImage;
            saveas(gcf, [outputFolder 'png/' num2str(100+c) 'I' num2str(1000+tt) 'sketch.png'], 'png');
            % write the html code for reproducibility page
            fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/' num2str(100+c) 'I' num2str(1000+tt) '.png' heightstr]);
            fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/' num2str(100+c) 'I' num2str(1000+tt) 'sketch.png' heightstr]);
            counter = counter + 1;
            if (counter == imagesPerRow)
                counter = 0; fprintf(fid, '%s\n', ['<br>']);
            end
        end %if
    end %img
    fprintf(fid, '%s\n', ['<hr>']);
end %c
fclose(fid);
