function res = compute_stat_STGConvNet(net, sequence, conserveMemory)
%COMPUTE_STAT_STGCONVNET  Per-layer statistics of a 3-D ConvNet on a sequence batch.
%
%   res = COMPUTE_STAT_STGCONVNET(net, sequence)
%   res = COMPUTE_STAT_STGCONVNET(net, sequence, conserveMemory)
%
%   Runs the input through every layer of NET (mex_conv3d, followed by
%   vl_nnrelu where applicable), then walks the layers in reverse order to
%   collect, for each layer, the weight and bias statistics returned by
%   mex_conv3d, averaged over size(sequence,5).
%
%   Inputs
%     net            - struct with a cell array net.layers. Each entry must
%                      provide .filters, .bias, .pad and .stride, which are
%                      forwarded to mex_conv3d. net.FC is read for the last
%                      layer only: the last layer is passed through
%                      vl_nnrelu only when net.FC == false.
%     sequence       - observed input stored as res(1).x. Its 5th dimension
%                      is used as the number of videos for averaging.
%                      NOTE(review): the remaining axes are presumably
%                      row x col x time x channel -- confirm with callers.
%     conserveMemory - optional, default true. When true, the .x and
%                      .indicator fields are emptied once they are no
%                      longer needed, so only the statistics and the size
%                      fields survive in the result.
%
%   Output
%     res - 1-by-(numLayers+1) struct array; res(1) belongs to the input
%           and res(l+1) to the output of layer l. Fields:
%             .x             feature map
%             .indicator     activation map propagated from the layer above
%             .stat_weights  df/dw averaged over videos, gathered
%             .stat_bias     df/db averaged over videos, gathered
%             .size_row, .size_col, .size_time
%                            extents of dims 1, 2 and 3 of the layer output
%                            (left empty for res(1))

if nargin < 3
    conserveMemory = true;
end

numLayers = numel(net.layers);   % number of convolutional layers
numVideos = size(sequence, 5);   % divisor used to average the statistics
topIdx    = numLayers + 1;       % index of the top-most entry of res

%% allocate one result slot for the input plus one per layer
res = struct(...
    'x', cell(1, topIdx), ...              % feature map
    'indicator', cell(1, topIdx), ...      % activation maps
    'stat_weights', cell(1, topIdx), ...   % df / dw
    'stat_bias', cell(1, topIdx), ...      % df / db
    'size_row', cell(1, topIdx), ...
    'size_col', cell(1, topIdx), ...
    'size_time', cell(1, topIdx));

res(1).x = sequence;  % the first layer of response map is the observed signal

%% forward pass: convolution (+ ReLU) layer by layer
for l = 1:numLayers
    layer = net.layers{l};

    res(l+1).x = mex_conv3d(res(l).x, layer.filters, layer.bias, ...
        'pad', layer.pad, 'stride', layer.stride);

    % Record the spatial/temporal extent of this layer's response.
    res(l+1).size_row  = size(res(l+1).x, 1);
    res(l+1).size_col  = size(res(l+1).x, 2);
    res(l+1).size_time = size(res(l+1).x, 3);

    % Every layer is rectified except the last one when net.FC is set.
    if l ~= numLayers || net.FC == false
        res(l+1).x = vl_nnrelu(res(l+1).x);
    end
end

%% seed the top-most indicator with ones
% Index with numLayers+1 explicitly instead of the leftover loop variable
% (which is empty when the loop body never ran), and allocate directly on
% the GPU rather than building the array on the host and copying it over.
res(topIdx).indicator = ones(size(res(topIdx).x), 'single', 'gpuArray');

%% backward pass: propagate the indicator and collect the statistics
for l = numLayers:-1:1
    layer = net.layers{l};

    % Mask the incoming indicator with the ReLU activation pattern for
    % exactly the layers that were rectified in the forward pass.
    if l ~= numLayers || net.FC == false
        res(l+1).indicator = vl_nnrelu(res(l+1).x, res(l+1).indicator);
    end

    % With the extra indicator argument mex_conv3d returns three outputs:
    % the indicator for the layer below plus the weight and bias statistics.
    % NOTE(review): presumably the MatConvNet-style backward mode -- confirm
    % against the mex_conv3d documentation.
    [res(l).indicator, res(l+1).stat_weights, res(l+1).stat_bias] = ...
        mex_conv3d(res(l).x, layer.filters, layer.bias, res(l+1).indicator, ...
        'pad', layer.pad, 'stride', layer.stride);

    % The upper layer's maps are not read again after this point.
    if conserveMemory
        res(l+1).indicator = [] ;
        res(l+1).x = [] ;
    end

    % Average over the videos and gather the result.
    res(l+1).stat_weights = gather(res(l+1).stat_weights * (1/numVideos));
    res(l+1).stat_bias = gather(res(l+1).stat_bias * (1/numVideos));
end

if conserveMemory
    res(1).indicator = [] ;
    res(1).x = [] ;
end






