function [net_cpu, syn_mats] = infer_nfa(opts, getBatch, subset, imdb, net_cpu, syn_mats, config)
% INFER_NFA  Update the per-batch network inputs (SYN_MATS) with the network held fixed.
%
%   [NET_CPU, SYN_MATS] = INFER_NFA(OPTS, GETBATCH, SUBSET, IMDB, NET_CPU, SYN_MATS, CONFIG)
%   walks SUBSET in chunks of OPTS.batchSize, loads each chunk with
%   GETBATCH(IMDB, BATCH), and refines the matrix stored in SYN_MATS{cell_idx}
%   for that chunk using the algorithm named by CONFIG.alg_type. Those matrices
%   are fed to the network as its input (presumably latent codes, given the
%   CONFIG.z_dim size - confirm). The network weights are not changed here;
%   NET_CPU is only moved to the compute device and back.
%
%   Fields read from OPTS:
%     gpus          - GPU ids; a non-empty list enables GPU execution
%     batchSize     - number of SUBSET entries per batch
%     numSubBatches - number of sub-batches each batch is split into
%     prefetch      - if true, GETBATCH is also called on the next sub-batch early
%   Fields read from CONFIG:
%     alg_type      - 'alter_grad' | 'joint_grad' | 'langevin_sampling'
%     z_dim         - second dimension of a freshly initialized SYN_MATS entry
%     joint_lambda, s, refsig - step size and scales used by the 'joint_grad' update
%
%   Outputs:
%     net_cpu  - the network moved back to the CPU
%     syn_mats - cell array of per-batch matrices (gathered to the CPU)

useGpu = numel(opts.gpus) >= 1;

% Work on a device copy of the network; only move it when GPUs were requested,
% so CPU runs (opts.gpus empty) do not touch gpuArray at all.
if useGpu
    net = vl_simplenn_move(net_cpu, 'gpu');
else
    net = net_cpu;
end

for t = 1:opts.batchSize:numel(subset)
    batchSize = min(opts.batchSize, numel(subset) - t + 1);
    batchTime = tic;
    numDone = 0;
    res = [];

    % SYN_MATS slot owned by this batch on this lab.
    % NOTE(review): every sub-batch of a batch shares this slot, so with
    % opts.numSubBatches > 1 each sub-batch overwrites the previous result and
    % may find a stored matrix sized for a different sub-batch - confirm intent.
    cell_idx = (ceil(t / opts.batchSize) - 1) * numlabs + labindex;

    for s = 1:opts.numSubBatches
        batchStart = t + (labindex-1) + (s-1) * numlabs;
        batchEnd = min(t + opts.batchSize - 1, numel(subset));
        batch = subset(batchStart : opts.numSubBatches*numlabs : batchEnd);

        im = getBatch(imdb, batch);
        if opts.prefetch
            % Call GETBATCH on the following sub-batch now (its result is
            % discarded; presumably GETBATCH uses this to start loading early).
            if s == opts.numSubBatches
                batchStart = t + (labindex-1) + opts.batchSize;
                batchEnd = min(t + 2*opts.batchSize - 1, numel(subset));
            else
                batchStart = batchStart + numlabs;
            end
            nextBatch = subset(batchStart : opts.numSubBatches*numlabs : batchEnd);
            getBatch(imdb, nextBatch);
        end

        if useGpu
            im = gpuArray(im);
        end

        fprintf('numlabs %2d, labindex %2d', numlabs, labindex);

        % Start from the stored matrix for this batch, or from zeros when the
        % slot is missing or empty.
        if cell_idx <= numel(syn_mats) && ~isempty(syn_mats{cell_idx})
            syn_mat = syn_mats{cell_idx};
        else
            syn_mat = zeros([1, config.z_dim, 1, size(im, 4)], 'single');
        end
        if useGpu
            syn_mat = gpuArray(syn_mat);
        end

        switch config.alg_type
            case 'alter_grad'
                % This update needs a slowness matrix, but infer_nfa no longer
                % receives one (the old slowness_mats{cell_idx} lookup is gone),
                % so the former call always failed on an undefined variable.
                error('infer_nfa:noSlownessMatrix', ...
                    ['config.alg_type ''alter_grad'' requires a slowness matrix, ' ...
                     'which infer_nfa does not receive.']);

            case 'joint_grad'
                % Back-propagate the residual (im - f(syn_mat)) through the net,
                % then step syn_mat along that gradient and shrink it toward zero.
                % Presumably a gradient step on a Gaussian likelihood (std config.s)
                % plus a Gaussian prior (std config.refsig) - confirm.
                fz = vl_simplenn(net, syn_mat);
                dydz = im - fz(end).x;
                res = vl_simplenn(net, syn_mat, dydz, res, 'conserveMemory', 1, 'cudnn', 1);
                syn_mat = syn_mat + config.joint_lambda / config.s / config.s * res(1).dzdx ...
                          - config.joint_lambda / config.refsig / config.refsig * syn_mat;
                syn_mats{cell_idx} = gather(syn_mat);

            case 'langevin_sampling'
                syn_mat = langevin_dynamic_z_test(config, net, im, syn_mat);
                syn_mats{cell_idx} = gather(syn_mat);

            otherwise
                error('infer_nfa:unknownAlgType', ...
                    'Unknown config.alg_type ''%s''.', config.alg_type);
        end

        numDone = numDone + numel(batch);
    end

    clear res;

    % Report timing for this batch.
    batchTime = toc(batchTime);
    speed = batchSize / batchTime;

    fprintf(' %.2f s (%.1f data/s)', batchTime, speed);
    fprintf(' [%d/%d]', numDone, batchSize);
    fprintf('\n');
end
net_cpu = vl_simplenn_move(net, 'cpu');
end













 