from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import tensorflow as tf
import numpy as np
import os
from .ops import *
from .util import *
from .optim import GradientDescent

class STGConvnet(object):
    """Spatial-temporal descriptor ConvNet trained with Langevin-sampled synthesis.

    Each epoch runs Langevin dynamics on persistent synthesized clips using the
    current descriptor, then accumulates and applies a weight update driven by
    the difference between the descriptor's responses on synthesized and
    observed clips.

    Constructing an instance has a filesystem side effect: the ``log``
    sub-directory of ``<output_dir>/<category>`` is deleted (if present) and
    re-created.
    """

    def __init__(self, sess, config):
        """Read hyper-parameters from ``config`` and create the input placeholders.

        Args:
            sess: session used for every graph evaluation.
            config: object exposing ``batch_size``, ``image_size``,
                ``num_frames``, ``num_chain``, ``num_epochs``, ``step_size``,
                ``sample_steps``, ``category``, ``data_path``, ``log_step`` and
                ``output_dir`` attributes.
        """
        self.sess = sess
        self.batch_size = config.batch_size
        self.image_size = config.image_size
        self.num_frames = config.num_frames
        self.num_chain = config.num_chain
        self.num_epochs = config.num_epochs

        self.step_size = config.step_size
        self.sample_steps = config.sample_steps

        self.category = config.category
        self.data_path = os.path.join(config.data_path, config.category)
        self.log_step = config.log_step
        self.output_dir = os.path.join(config.output_dir, config.category)

        self.log_dir = os.path.join(self.output_dir, 'log')
        self.train_dir = os.path.join(self.output_dir, 'observed_sequence')
        self.sample_dir = os.path.join(self.output_dir, 'synthesis_sequence')
        self.model_dir = os.path.join(self.output_dir, 'model')
        self.result_dir = os.path.join(self.output_dir, 'final_result')

        # Start every run with an empty TensorBoard log directory.
        if tf.gfile.Exists(self.log_dir):
            tf.gfile.DeleteRecursively(self.log_dir)
        tf.gfile.MakeDirs(self.log_dir)

        # Synthesized clips: exactly ``num_chain`` chains are fed at a time.
        self.syn = tf.placeholder(shape=[self.num_chain, self.num_frames, self.image_size, self.image_size, 3], dtype=tf.float32)
        # Observed clips: batch dimension left open so the last, possibly
        # smaller, batch of training data can be fed.
        self.obs = tf.placeholder(shape=[None, self.num_frames, self.image_size, self.image_size, 3], dtype=tf.float32)

    def descriptor(self, inputs, reuse=False):
        """Build the three-layer 3-D convolutional descriptor under scope ``des``.

        Args:
            inputs: 5-D clip tensor (one of ``self.obs`` / ``self.syn``).
            reuse: pass True on every call after the first to share weights.

        Returns:
            ``(conv3, layers)`` where ``conv3`` is the final ReLU response and
            ``layers`` is a list of dicts with keys ``'output'``, ``'lr'``
            (per-layer learning rate) and ``'vars'`` (trainable variables whose
            name contains the layer name).
        """
        layers = []
        with tf.variable_scope('des', reuse=reuse):
            conv1 = conv3d(inputs, 120, (7, 7, 7), strides=(2, 2, 2), padding="SAME", activation_fn=tf.nn.relu, name="conv1")
            # NOTE(review): 'vars' is a substring match over ALL trainable
            # variables in the graph, not only those under 'des'.
            layers.append({'output': conv1, 'lr': 0.01,
                           'vars': [var for var in tf.trainable_variables() if 'conv1' in var.name]})

            # NOTE(review): the 50x50 spatial kernel with zero spatial padding
            # presumably targets a specific image_size — confirm before changing it.
            conv2 = conv3d(conv1, 30, (5, 50, 50), strides=(2, 1, 1), padding=(2, 0, 0), activation_fn=tf.nn.relu, name="conv2")
            layers.append({'output': conv2, 'lr': 0.001,
                           'vars': [var for var in tf.trainable_variables() if 'conv2' in var.name]})

            conv3 = conv3d(conv2, 5, (2, 1, 1), strides=(1, 2, 2), padding=(1, 0, 0), activation_fn=tf.nn.relu, name="conv3")
            layers.append({'output': conv3, 'lr': 0.0001,
                           'vars': [var for var in tf.trainable_variables() if 'conv3' in var.name]})
            return conv3, layers

    def langevin_dynamics(self, samples, gradient):
        """Run ``self.sample_steps`` Langevin updates on ``samples``.

        Each step computes ``x <- x - 0.5 * delta**2 * (x - g) + delta * eps``,
        where ``delta`` is ``self.step_size``, ``g`` is ``gradient`` evaluated
        with ``self.syn`` fed ``x``, and ``eps`` is standard normal noise. The
        ``x`` term presumably acts as the gradient of a unit-variance Gaussian
        reference distribution.

        Args:
            samples: array matching the ``self.syn`` placeholder shape.
            gradient: tensor to evaluate at each step (``train`` passes the
                gradient of the descriptor response w.r.t. ``self.syn``).

        Returns:
            The updated samples (a new array; the input is not modified
            in place when at least one step runs).
        """
        for _ in range(self.sample_steps):
            noise = np.random.randn(*samples.shape)
            grad = self.sess.run(gradient, feed_dict={self.syn: samples})
            samples = samples - 0.5 * self.step_size * self.step_size * (samples - grad) + self.step_size * noise
        return samples

    @staticmethod
    def _ceil_div(numerator, denominator):
        """Return ``ceil(numerator / denominator)`` for non-negative ints, using integer math only."""
        return -(-numerator // denominator)

    def train(self, debug=True, result_step=50):
        """Build the training graph and fit the descriptor for ``num_epochs`` epochs.

        Args:
            debug: if True, print per-layer ``ssd_weight`` means and maximum
                activation counts reported by the optimizer after each epoch.
            result_step: write a full result video every ``result_step``
                epochs (must be positive; defaults to the previous fixed 50).

        The printed loss and MSE come from streaming metrics whose local
        variables are initialized once before the first epoch, so they are
        running averages over all epochs so far, not per-epoch values.
        """
        obs_res, layers = self.descriptor(self.obs, reuse=False)
        syn_res, _ = self.descriptor(self.syn, reuse=True)
        # Difference of batch-mean responses: synthesized minus observed.
        train_loss = tf.subtract(tf.reduce_mean(syn_res, axis=0), tf.reduce_mean(obs_res, axis=0))
        train_loss_mean, train_loss_update = tf.contrib.metrics.streaming_mean(train_loss)

        recon_err_mean, recon_err_update = tf.contrib.metrics.streaming_mean_squared_error(
            tf.reduce_mean(self.syn, axis=0), tf.reduce_mean(self.obs, axis=0))

        # Gradient of the descriptor response w.r.t. the synthesized input,
        # consumed by langevin_dynamics.
        dLdI = tf.gradients(syn_res, self.syn)[0]

        # Prepare training data, centred by its global (scalar) mean.
        loadVideoToFrames(self.data_path, self.train_dir)
        train_data = getTrainingData(self.train_dir, num_frames=self.num_frames, image_size=self.image_size)
        img_mean = train_data.mean()
        train_data = train_data - img_mean
        print('Working on training video {}, shape: {}'.format(self.category, train_data.shape))

        # Integer ceiling division: the last batch may be smaller than batch_size.
        # (Avoids depending on `math` leaking in through a star import.)
        num_batches = self._ceil_div(len(train_data), self.batch_size)

        optim = GradientDescent(train_loss, layers, num_batches=num_batches, cap=40)

        self.sess.run(tf.global_variables_initializer())
        # Streaming-metric state lives in local variables; initialized only here.
        self.sess.run(tf.local_variables_initializer())

        # One persistent block of num_chain chains per data batch, carried
        # across epochs.
        sample_size = self.num_chain * num_batches
        sample_video = np.random.randn(sample_size, self.num_frames, self.image_size, self.image_size, 3).astype(float)

        tf.summary.scalar('train_loss', train_loss_mean)
        tf.summary.scalar('recon_err', recon_err_mean)
        summary_op = tf.summary.merge_all()

        saver = tf.train.Saver(max_to_keep=50)
        writer = tf.summary.FileWriter(self.log_dir, self.sess.graph)

        try:
            for epoch in range(self.num_epochs):
                self.sess.run(optim.reset_gradient)
                for i in range(num_batches):
                    obs_data = train_data[i * self.batch_size:(i + 1) * self.batch_size]
                    chains = slice(i * self.num_chain, (i + 1) * self.num_chain)

                    syn = self.langevin_dynamics(sample_video[chains], dLdI)

                    # Accumulate this batch's gradient and update both streaming
                    # metrics in a single session call (same feeds).
                    self.sess.run([optim.update_gradient, train_loss_update, recon_err_update],
                                  feed_dict={self.obs: obs_data, self.syn: syn})

                    # Persist the chain state for the next epoch.
                    sample_video[chains] = syn

                _, ssd_weight, maximum_activations = self.sess.run(
                    [optim.apply_gradient, optim.ssd_weight, optim.maximum_activations])
                if debug:
                    print('##############################################################################')
                    for i, layer_weight in enumerate(ssd_weight):
                        print('Layer {}, ssd_weight: {} maximum_num_activations: {:d}'.format(i, np.mean(layer_weight), maximum_activations[i]))
                loss, recon_err, summary = self.sess.run([train_loss_mean, recon_err_mean, summary_op])
                print('Epoch #%d, descriptor loss: %.4f, Avg MSE: %4.4f' % (epoch, loss, recon_err))
                writer.add_summary(summary, epoch)

                if epoch % self.log_step == 0:
                    # MakeDirs is a no-op when the directory already exists.
                    tf.gfile.MakeDirs(self.sample_dir)
                    saveSampleSequence(sample_video + img_mean, self.sample_dir, epoch, col_num=10)

                    tf.gfile.MakeDirs(self.model_dir)
                    saver.save(self.sess, "%s/%s" % (self.model_dir, 'model.ckpt'), global_step=epoch)

                if epoch % result_step == 0:
                    tf.gfile.MakeDirs(self.result_dir)
                    saveSampleVideo(sample_video + img_mean, self.result_dir, original=(train_data + img_mean), global_step=epoch)
        finally:
            # FileWriter only flushes periodically; close it so no events are lost.
            writer.close()
