% Compute the reconstruction error on the testing images (NFA vs. PCA).
% Run this script after the standard learning procedure has finished, so
% that net, config, imdb and the syn_mats cell are already in the workspace.

%% read the testing images and record them into imdb
% Load the configuration for the testing category and convert the testing
% images into an imdb structure plus a matching batch getter.
test_category = 'test_300';
config = test_nfa_config(test_category, config);

imgCell_test = read_images_test(config, net);
[imdb_test, getBatch_test] = convert2imdb(imgcell2mat(imgCell_test));

%% infer latent factor z for the testing images (syn_mats_test)
% Every image with images.set == 1 is an inference target.
opts.train = find(imdb_test.images.set==1);
if isempty(opts.train)
    % Otherwise batchSize becomes 0 and the slot count below is 0/0 = NaN.
    error('No testing images found in imdb_test (images.set == 1).');
end
% Never use a batch larger than the number of testing images.
opts.batchSize = min(config.BatchSize_test, numel(opts.train));
opts.numSubBatches = 1 ;
opts.gpus = config.gpus;
opts.prefetch = false;

% One cell slot per batch (and per GPU). With opts.gpus empty (CPU mode)
% multiplying by numel(opts.gpus) gave 0 slots, although syn_mats_test{1}
% is read later; always count at least one device.
num_syns_test = ceil(numel(opts.train) / opts.batchSize) * max(1, numel(opts.gpus));
syn_mats_test = cell(1, num_syns_test);
%slowness_mats_test = cell(1, num_syns_test);

[net, syn_mats_test] = infer_nfa(opts, getBatch_test, opts.train, imdb_test, net, syn_mats_test, config);

% compute the l2 loss, use the loss computation function in train_model_nfa
loss_nfa = compute_loss_l2(opts, imdb_test, getBatch_test, opts.train, net, syn_mats_test);

% Reconstruct images from the inferred z of the FIRST testing batch only.
infer_test_z = syn_mats_test{1};
im_test_infer = vl_simplenn(net, infer_test_z, [], [], ...
             'conserveMemory', 1, ...
             'cudnn', 1);
I_test_infer = im_test_infer(end).x;
% Compare only against the images covered by that batch, so the sizes
% agree even when the testing set spans more than one batch.
% NOTE(review): assumes infer_nfa stores the first batch of images (in
% imdb order) in syn_mats_test{1} - confirm.
num_recon = size(I_test_infer, 4);
% sqrt(d.^2) == |d|: this is the mean absolute pixel error.
diff2_sqrt = abs(I_test_infer - imdb_test.images.data(:,:,:,1:num_recon));
error_recon = mean(diff2_sqrt(:));
fprintf('The testing mean absolute reconstruction error is %g\n', error_recon);
draw_figures_reconstruct(config, I_test_infer,'NFA_300');
% %02d is an integer conversion; %g prints the fractional loss correctly.
fprintf('The testing reconstruct error is %g\n', loss_nfa);

% write the inferred_test_images into the folder
%num_test_infer = size(I_test_infer, 4);
%for i_infer = 1:num_test_infer
%    gLow = min( reshape(I_test_infer(:,:,:,i_infer), [],1));
%    gHigh = max(reshape(I_test_infer(:,:,:,i_infer), [],1));
%    I_test_infer(:,:,:,i_infer) = (I_test_infer(:,:,:,i_infer)-gLow) / (gHigh - gLow);
%    imwrite(I_test_infer(:,:,:,i_infer), [config.test_folder, num2str(i_infer, 'reconstruct_image_%02d'), '.png']);
%end

% get the reconstruction error of NFA
%diff2_sqrt = sqrt((I_test_infer - imdb_test.images.data).^2);
%error_recon = mean(diff2_sqrt(:));


%% compare to the PCA
% imdb.images.data holds the training data and imdb_test.images.data the
% testing data, both in 4-D layout. Flatten each image into one column,
% giving (pixels x images) matrices.
% reshape walks the data column-major, exactly like current_img(:) for each
% image, but avoids growing the matrices one column at a time.
train_Y_scale = reshape(imdb.images.data, [], size(imdb.images.data, 4));
test_Y_scale = reshape(imdb_test.images.data, [], size(imdb_test.images.data, 4));

% or we can directly read the images
% Read every .jpg in config.inPath_test, center-crop it to
% config.cropped_sz x config.cropped_sz, resize it to [config.sx, config.sy]
% and store each vectorized image as one column of test_Y.
files_test = dir(fullfile(config.inPath_test, '*.jpg'));
numImages_test = length(files_test);
if numImages_test == 0
    % test_Y is required by the PCA comparison below.
    error('No .jpg testing images found in %s', config.inPath_test);
end
imgCell_test = cell(1, numImages_test);
for iImg = 1:numImages_test
    img_large = single(imread(fullfile(config.inPath_test, files_test(iImg).name)));
    % Top-left corner of the centered crop. Clamp to 1: when a side equals
    % cropped_sz, round(...) is 0, which is not a valid MATLAB index.
    h = max(1, round((size(img_large, 1) - config.cropped_sz)/2));
    w = max(1, round((size(img_large, 2) - config.cropped_sz)/2));
    cropped_img = img_large(h:h+config.cropped_sz-1, w:w+config.cropped_sz-1, :);
    imgCell_test{iImg} = imresize(cropped_img, [config.sx, config.sy]);
end
% Concatenate all vectorized images at once instead of growing test_Y.
test_Y = cell2mat(cellfun(@(im) im(:), imgCell_test, 'UniformOutput', false));

% may change this
%inPath_train = ['../../Image/', config.categoryName '/'];
% Read up to max_train_pca .jpg images from config.inPath, preprocess them
% the same way as the testing images, and store them as columns of train_Y.
max_train_pca = 1000;
files_train = dir(fullfile(config.inPath, '*.jpg'));
% The original loop was hard-coded to 1:1000 and crashed on smaller folders.
numImages_train = min(max_train_pca, length(files_train));
if numImages_train == 0
    % train_Y is required to learn the PCA basis below.
    error('No .jpg training images found in %s', config.inPath);
end
imgCell_train = cell(1, numImages_train);
for iImg = 1:numImages_train
    img_large = single(imread(fullfile(config.inPath, files_train(iImg).name)));
    % Top-left corner of the centered crop, clamped to a valid index.
    h = max(1, round((size(img_large, 1) - config.cropped_sz)/2));
    w = max(1, round((size(img_large, 2) - config.cropped_sz)/2));
    cropped_img = img_large(h:h+config.cropped_sz-1, w:w+config.cropped_sz-1, :);
    imgCell_train{iImg} = imresize(cropped_img, [config.sx, config.sy]);
end
% Concatenate all vectorized images at once instead of growing train_Y.
train_Y = cell2mat(cellfun(@(im) im(:), imgCell_train, 'UniformOutput', false));



% learn the basis Chat
% PCA of the training images: centre every pixel on its training mean, take
% an economy SVD, keep the n leading left singular vectors as the basis Chat
% and S*V' restricted to those n components as the training coefficients.
n = config.z_dim; % pca dimension
Ymean_train = mean(train_Y, 2);
tau_train = size(train_Y, 2);
[U, S, V] = svd(bsxfun(@minus, train_Y, Ymean_train), 0);
Chat = U(:, 1:n);
Xhat_train = S(1:n, 1:n) * V(:, 1:n)';

% project the testing data into basis Chat and do the reconstruction
% The PCA coefficients are (C'C)^(-1) C' Y and the reconstruction is
% C (C'C)^(-1) C' Y, i.e. the orthogonal projection of Y onto span(C).
% NOTE(review): the test images are centred with their OWN mean
% (Ymean_test), not the training mean Ymean_train, so the PCA baseline
% uses test-set statistics - confirm this is intended.
Ymean_test = mean(test_Y, 2);
tau_test = size(test_Y, 2);
Y_test_submean = bsxfun(@minus, test_Y, Ymean_test);
% Solve with mldivide rather than forming inv(C'C) explicitly.
recon_Y_submean = Chat * ((Chat' * Chat) \ (Chat' * Y_test_submean));
recon_Y = bsxfun(@plus, recon_Y_submean, Ymean_test);

% convert the recon_Y to range [-1,1]
% Min-max normalize every reconstructed image (column) independently.
% A constant column has zero range; dividing by 1 instead of 0 maps it to
% -1 instead of producing NaN (0/0).
col_min = min(recon_Y, [], 1);
col_range = max(recon_Y, [], 1) - col_min;
col_range(col_range == 0) = 1;
% Same operation order as 2*(x - min)/(max - min) - 1.
recon_Y_scale = single(bsxfun(@rdivide, 2 * bsxfun(@minus, recon_Y, col_min), col_range) - 1);

% get the reconstruction error for PCA and then draw the figure
% sqrt(d.^2) == |d|: this is the mean absolute pixel error between the
% rescaled PCA reconstructions and the flattened testing images.
diff2_sqrt_pca = abs(recon_Y_scale - test_Y_scale);
loss_pca = mean(diff2_sqrt_pca(:));
% %02d is an integer conversion; %g prints the fractional loss correctly.
fprintf('The testing reconstruct error for PCA is %g\n', loss_pca);

%transform the recon_Y_scale from 2D to 4D matconvnet format
% Each column is one image; reshape restores [sx, sy, 3, numImages] in the
% same column-major order as the per-column loop it replaces.
recon_Y_4D = reshape(recon_Y_scale, config.sx, config.sy, 3, []);

% Derive the label from the PCA dimension actually used (config.z_dim)
% instead of hard-coding '200dim'.
draw_figures_reconstruct(config, recon_Y_4D, sprintf('PCA_300_%ddim', config.z_dim));
draw_figures_reconstruct(config, imdb_test.images.data,'original');





