% SHARED SKETCH ALGORITHM FOR LEARNING ACTIVE BASIS
%% Load in exponential model, mex C codes, and set parameters
% ExponentialModel; % this line is not needed if 'storedExponentialModel' exists
% NOTE(review): storedExponentialModel.mat presumably also defines the other
% model/filter variables used later but never assigned in this script
% (allFilter, numOrient, saturation, localOrNot, halfFilterSize, ...) -- TODO confirm.
clear; close all; 
load 'storedExponentialModel'; % load in exponential model 
mex Clearn.c; % learning by shared sketch algorithm
mex Csigmoid.c; % local normalization
epsilon = .1; % allowed correlation between selected Gabors 
subsample = 1; % subsample in computing MAX1 maps   
locationShiftLimit = 2; % shift in normal direction = locationShiftLimit*subsample pixels
orientShiftLimit = 1; % shift in orientation
numElement = 15; % number of Gabors in active basis   
numCluster = 2; % number of clusters
numIteration = 15; % number of EM iterations
Correlation = CorrFilter(allFilter, epsilon); % correlation between filters 
% Figures go to <outputFolder>eps and <outputFolder>png; the report page is
% <outputFolder>.html. Results of a previous run are removed first.
outputFolder = 'digits';
if (~exist([outputFolder 'eps'], 'dir')), mkdir([outputFolder 'eps']); end % mkdir warns if it exists
if (~exist([outputFolder 'png'], 'dir')), mkdir([outputFolder 'png']); end
delete([outputFolder 'eps/*.*']); 
delete([outputFolder 'png/*.*']); 
if (exist([outputFolder '.html'], 'file')), delete([outputFolder '.html']); end
fid = fopen([outputFolder '.html'], 'wt');
if (fid < 0) % fopen returns -1 on failure; fail here rather than at the first fprintf
    error(['Cannot open ' outputFolder '.html for writing']);
end
%% Load in training images
% Only imageOption = 2 (resize every image to sizex-by-sizey) is implemented by
% the loader below; options 1 and 3 are kept for reference.
%imageOption = 1; resizeFactor = .6; % resize images to specified ratio, and use the common upper-left portion 
imageOption = 2; sizex = 60; sizey = 60; % resize images to specified common height and width
%imageOption = 3; sizey = 100; % resize images to have common width without changing their aspect ratios, and use the common horizontal central line
if (imageOption ~= 2)
    error('Only imageOption = 2 is supported by the MNIST loader in this script');
end
digitList = 9; % digit classes to load; digit k is read from digit<k>.mat
nn = 200;      % maximum number of images taken per digit
I = cell(1, numel(digitList) * nn);
tt = 0; 
for (k = digitList)
    load(['digit' num2str(k) '.mat']); % defines D; each row is reshaped into a 28x28 image
    for (j = 1 : min(nn, size(D, 1))) % do not index past the last stored image
        tt = tt + 1; 
        % Transpose BEFORE resizing so every image is sizex-by-sizey
        % (rows-by-columns), matching allSizex/allSizey and the templates below.
        % For square sizes this equals the former resize-then-transpose.
        I{tt} = single(imresize(reshape(D(j, :), 28, 28)', [sizex sizey], 'nearest'));
    end
end
I = I(1 : tt); 
numImage = tt; 
if (numImage == 0)
    error('No training images were loaded');
end
allSizex = zeros(1, numImage) + sizex; 
allSizey = zeros(1, numImage) + sizey; 
disp(['number of training images = ' num2str(numImage)]);
%% Compute SUM1 maps by Gabor filtering
% Filters every training image with the Gabor bank (allFilter), then runs the
% Csigmoid mex. Csigmoid is called without outputs and presumably transforms
% SUM1map in place (TODO confirm). SUM1mapBackUp is produced by mcopy,
% presumably a separate copy of the transformed maps (TODO confirm).
disp('start filtering');
filterClock = tic;
if (imageOption ~= 3)
    % SUM1 maps by Gabor filtering, same size as the input images
    SUM1map = ApplyFilterfftSame(I, allFilter, localOrNot, localHalfx, localHalfy, -1, thresholdFactor);
else
    % SUM1 maps by Gabor filtering; the filled images replace I
    [SUM1map, Ifill] = ApplyFilterfftFill(I, allFilter, localOrNot, localHalfx, localHalfy, thresholdFactor);
    I = Ifill;
end
Csigmoid(numImage, allSizex, allSizey, numOrient, saturation, SUM1map);
SUM1mapBackUp = mcopy(SUM1map, numImage, numOrient);
elapsed = toc(filterClock);
disp(['filtering time: ' num2str(elapsed) ' seconds']);
%% Prepare output variables for learning
% Clearn is called without outputs, and its results are read back from these
% same variables, so it presumably writes into them in place. Each array is
% therefore allocated separately. Do not build them with deal() or by copying
% one into another: MATLAB copies share memory until modified from MATLAB
% code, so an in-place write from C could then reach several variables.
SUM2score = zeros(numImage, 1);  % template matching score of each training image
commonTemplate = zeros(sizex, sizey, 'single');  % template of active basis
deformedTemplate = cell(1, numImage);  % one deformed template per training image
for (img = 1 : numImage)
    deformedTemplate{img} = zeros(sizex, sizey, 'single');
end
selectedOrient = zeros(1, numElement);  % orientation of each selected Gabor
selectedx      = zeros(1, numElement);  % x location of each selected Gabor
selectedy      = zeros(1, numElement);  % y location of each selected Gabor
selectedlambda = zeros(1, numElement);  % weighting parameter for scoring template matching
selectedLogZ   = zeros(1, numElement);  % normalizing constant
%% Prepare variables for EM
% Initial fractional cluster memberships (one row per image, rows sum to 1).
% Set useStoredDataWeight = true to reuse the weights saved in
% dataWeightAll.mat by an earlier run, for strict reproducibility. Otherwise
% new random weights are drawn and saved there for later reuse.
useStoredDataWeight = false;
prob = zeros(1, numCluster); % mixing proportion of each cluster
SUM2scoreAll = zeros(numImage, numCluster); % SUM2 score of every image under every cluster template
haveWeights = false;
if (useStoredDataWeight && exist('dataWeightAll.mat', 'file'))
    stored = load('dataWeightAll.mat');
    % Reuse the stored weights only if they fit the current images and clusters.
    haveWeights = isfield(stored, 'dataWeightAll') && ...
                  isequal(size(stored.dataWeightAll), [numImage, numCluster]);
    if (haveWeights)
        dataWeightAll = stored.dataWeightAll;
    else
        warning('dataWeightAll.mat does not hold a numImage-by-numCluster dataWeightAll; drawing new random weights');
    end
end
if (~haveWeights)
    dataWeightAll = rand(numImage, numCluster);
    % normalize each row so an image's weights over all clusters sum to 1
    dataWeightAll = dataWeightAll ./ repmat(sum(dataWeightAll, 2), 1, numCluster); 
    save dataWeightAll dataWeightAll; 
end
commonTemplateAll = cell(1, numCluster); 
deformedTemplateAll = cell(numImage, numCluster); 
%% EM iteration
% Each iteration has two steps. The M-step re-learns one active-basis template
% per cluster with Clearn, using that cluster's fractional data weights. The
% E-step then recomputes the fractional memberships from the SUM2 scores.
% NOTE(review): Clearn is called without outputs, and its results are read back
% from SUM2score, commonTemplate and deformedTemplate, so it appears to write
% into those arrays in place. The "+ 0." expressions force fresh copies so the
% stored per-cluster results are not overwritten by the next Clearn call.
for (it = 1 : numIteration)
    disp(['M-step of iteration ' num2str(it)]);
    for (c = 1 : numCluster)
        tic
        dataWeight = dataWeightAll(:, c) + 0.; % fresh copy of this cluster's weights
        % learning for cluster c
        Clearn(numOrient, locationShiftLimit, orientShiftLimit, saturation, subsample, ... % about active basis  
            numElement, numImage, sizex, sizey, SUM1map, SUM1mapBackUp, ... % about training images 
            halfFilterSize, Correlation, allSymbol(1, :), ... % about filters
            numStoredPoint, storedlambda, storedExpectation, storedLogZ, ... % about exponential model 
            selectedOrient, selectedx, selectedy, selectedlambda, selectedLogZ, SUM2score, ... % learned parameters
            commonTemplate, deformedTemplate, dataWeight); % learned templates 
        disp(['   mex-C learning time for cluster ' num2str(c) ' takes ' num2str(toc) ' seconds']);
        SUM2scoreAll(:, c) = SUM2score + 0.; 
        commonTemplateAll{c} = commonTemplate + 0.; 
        for (img = 1 : numImage)
            deformedTemplateAll{img, c} = deformedTemplate{img} + 0.; 
        end
    end
    %% E-step fractional classification
    % Weight of cluster c for one image:
    %   prob(c)*exp(S(c)) / sum_c' prob(c')*exp(S(c')),  S = SUM2 scores.
    % This is evaluated in the log domain with each row's maximum subtracted.
    % That way exp() cannot overflow, and a cluster with prob == 0
    % (log(prob) = -Inf) gets weight 0 instead of producing 0*Inf = NaN.
    disp(['E-step of iteration ' num2str(it)]);
    prob = sum(dataWeightAll, 1) / numImage; % current mixing proportions
    logPost = bsxfun(@plus, SUM2scoreAll, log(prob));
    logPost = bsxfun(@minus, logPost, max(logPost, [], 2));
    posterior = exp(logPost);
    dataWeightAll = bsxfun(@rdivide, posterior, sum(posterior, 2));
end %it

%% save figures into eps and png folders
% Runs once, after the final EM iteration. For each cluster it does three things:
%   - saves the learned template;
%   - for every image whose largest weight is on this cluster, taken in
%     descending order of SUM2 score, saves the image and its deformed template;
%   - writes matching <IMG> tags into the HTML report.
% NOTE(review): showImage is run as a script and presumably draws the workspace
% variable pic into the current figure -- TODO confirm.
if (numIteration >= 1)
    disp('Displaying the clusters');
    fprintf(fid, '%s\n', ['<a href="' outputFolder 'Code.zip">Code and data</a> data source: MNIST<br>']);
    outMessage = ['==> Number of training images = ' num2str(numImage) ...
                  '; Image height and width = '  num2str(sizex) ' by ' num2str(sizey) ' pixels ' ...
                  '; Number of clusters = ' num2str(numCluster) ...
                  '; Number of elements = ' num2str(numElement) ...
                  '; Range of displacement = ' num2str(locationShiftLimit*subsample) ' pixels' ... % shift is locationShiftLimit*subsample pixels
                  '; Number of iterations = ' num2str(numIteration) ...
                  '; Length of Gabor = ' num2str(halfFilterSize*2+1) ' pixels' ...
                  '; Local normalization or not = ' num2str(localOrNot) ...
                  '; Subsample rate = ' num2str(subsample) ' pixel' ...                     
                  '<br>'];
    fprintf(fid, '%s\n', outMessage);
    disp(outMessage);                             
    fprintf(fid, '%s\n', '<hr>');
    % The leading space separates the height attribute from the SRC attribute.
    heightstr0 = '" height=120> '; heightstr = '" height=80> ';
    imagesPerRow = 6; % image/sketch pairs per report row (not named length: that shadows the built-in)
    for (c = 1 : numCluster)      
        fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/000template' num2str(100+c) '.png' heightstr0]);
    end
    fprintf(fid, '%s\n', '<hr>');
    for (c = 1 : numCluster)      
        commonTemplate = commonTemplateAll{c} + 0.; 
        pic = -double(commonTemplate); showImage; 
        saveas(gcf, [outputFolder 'eps/000template' num2str(100+c) '.eps'], 'eps');
        saveas(gcf, [outputFolder 'png/000template' num2str(100+c) '.png'], 'png');
        pic = commonTemplate*0. + 255; showImage; % blank image (all values 255)
        saveas(gcf, [outputFolder 'eps/empty.eps'], 'eps');
        saveas(gcf, [outputFolder 'png/empty.png'], 'png');   
        fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/000template' num2str(100+c) '.png' heightstr0]);
        fprintf(fid, '%s\n', '<br>');
        SUM2score = SUM2scoreAll(:, c) + 0.; 
        [SUM2scoresort, ind] = sort(SUM2score, 'descend'); 
        tt = 0; counter = 0; 
        for (i = 1 : numImage)
            this = ind(i); 
            % list an image under the cluster(s) holding its largest weight
            if (dataWeightAll(this, c) == max(dataWeightAll(this, :)))
                tt = tt + 1; close all; 
                fileStem = [num2str(100+c) 'I' num2str(1000+tt)];
                pic = I{this}; showImage; 
                saveas(gcf, [outputFolder 'eps/' fileStem '.eps'], 'eps');
                saveas(gcf, [outputFolder 'png/' fileStem '.png'], 'png');    
                pic = -(deformedTemplateAll{this, c}); showImage; 
                saveas(gcf, [outputFolder 'eps/' fileStem 'sketch.eps'], 'eps');
                saveas(gcf, [outputFolder 'png/' fileStem 'sketch.png'], 'png');
                % write the html code for reproducibility page
                fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/' fileStem '.png' heightstr]);
                fprintf(fid, '%s\n', ['<IMG SRC="' outputFolder 'png/' fileStem 'sketch.png' heightstr]);
                counter = counter + 1; 
                if (counter == imagesPerRow)
                    counter = 0; fprintf(fid, '%s\n', '<br>');
                end
            end %if
        end %i
        fprintf(fid, '%s\n', '<hr>'); 
    end %c
end %if
fclose(fid); % always release the report file, even when no EM iteration ran


