import tensorflow as tf

def clip_by_abs(v, clip_v):
    """Rescale ``v`` so its largest absolute entry does not exceed ``clip_v``.

    Every element is multiplied by the same factor, so the relative
    proportions between entries are kept. When the largest absolute entry
    is already at or below ``clip_v`` the factor is 1 and ``v`` comes back
    unchanged (up to floating-point rounding).

    Args:
        v: Tensor to rescale.
        clip_v: Upper bound for the largest absolute entry of the result.

    Returns:
        A tensor with the same shape as ``v``. An all-zero ``v`` yields an
        all-zero result.
    """
    max_val = tf.reduce_max(tf.abs(v))
    # An all-zero tensor has max_val == 0, and 0 / 0 would turn every entry
    # into NaN. Substitute a denominator of 1 in exactly that case; the
    # numerator is zero anyway, so the result is the expected zero tensor.
    # Non-zero inputs go through the same arithmetic as before.
    is_zero = tf.equal(max_val, tf.zeros_like(max_val))
    safe_max = tf.where(is_zero, tf.ones_like(max_val), max_val)
    return v / safe_max * tf.minimum(max_val, clip_v)

class GradientDescent(object):
    """Per-layer gradient descent with accumulated, averaged, clipped gradients.

    For every entry of ``layers`` the constructor creates a
    ``tf.train.GradientDescentOptimizer`` and a set of graph ops stored in a
    dict inside ``self.states``:

    * ``accum_vars``    -- non-trainable variables, one per layer variable,
                           initialised to zeros.
    * ``reset_op``      -- assigns zeros to every accumulator.
    * ``update_op``     -- adds the gradient of ``loss`` to the accumulators.
    * ``avg_grads``     -- accumulators divided by ``num_batches``.
    * ``ssd_weight``    -- mean absolute averaged gradient, only for
                           variables whose name contains ``'w'``.
    * ``clipped_grads`` -- averaged gradients divided by the layer's
                           ``maximum_activations`` and passed through
                           ``clip_by_abs`` with ``cap``.
    * ``apply_op``      -- applies the clipped gradients to the variables.

    The properties expose each of these as a list with one entry per layer.
    """

    def __init__(self, loss, layers, num_batches, cap=40):
        """Build all optimizer ops for the given layers.

        Args:
            loss: Tensor differentiated with respect to each layer's variables.
            layers: Iterable of dicts providing ``'lr'`` (learning rate),
                ``'vars'`` (variables to train) and ``'output'`` (an object
                with ``get_shape()``; dimensions 1..3 are multiplied to get
                ``maximum_activations``).
            num_batches: Divisor used to average the accumulated gradients.
            cap: Bound handed to ``clip_by_abs`` for the scaled gradients.
        """
        self.max_gradient = cap
        self.loss = loss
        self.states = []
        self.num_batches = num_batches

        # Phase 1: one optimizer and one activation count per layer.
        for layer in layers:
            optimizer = tf.train.GradientDescentOptimizer(layer['lr'])
            layer_vars = layer['vars']
            rate = layer['lr']
            activation_count = tf.reduce_prod(layer['output'].get_shape()[1:4])
            self.states.append({
                'optim': optimizer,
                'vars': layer_vars,
                'lr': rate,
                'maximum_activations': activation_count,
            })

        # Phase 2: accumulation, averaging, clipping and apply ops per layer.
        for state in self.states:
            optimizer = state['optim']
            layer_vars = state['vars']

            accumulators = []
            for variable in layer_vars:
                zeros = tf.zeros_like(variable.initialized_value())
                accumulators.append(tf.Variable(zeros, trainable=False))
            state['accum_vars'] = accumulators

            resets = []
            for acc in accumulators:
                resets.append(acc.assign(tf.zeros_like(acc)))
            state['reset_op'] = resets

            # compute_gradients yields (gradient, variable) pairs in the same
            # order as var_list, so they line up with the accumulators.
            grad_pairs = optimizer.compute_gradients(loss, var_list=layer_vars)
            state['update_op'] = [
                acc.assign_add(grad)
                for acc, (grad, _) in zip(accumulators, grad_pairs)
            ]

            averaged = []
            for acc in accumulators:
                averaged.append(tf.divide(acc, self.num_batches))
            state['avg_grads'] = averaged

            weight_magnitudes = []
            for grad, variable in zip(averaged, layer_vars):
                if 'w' in variable.name:
                    weight_magnitudes.append(tf.reduce_mean(tf.abs(grad)))
            state['ssd_weight'] = weight_magnitudes

            clipped = []
            for grad in averaged:
                per_activation = grad / tf.cast(state['maximum_activations'], tf.float32)
                clipped.append(clip_by_abs(per_activation, self.max_gradient))
            state['clipped_grads'] = clipped

            state['apply_op'] = optimizer.apply_gradients(zip(clipped, layer_vars))

    def _collect(self, key):
        """Return ``state[key]`` for every layer state, in layer order."""
        collected = []
        for state in self.states:
            collected.append(state[key])
        return collected

    @property
    def maximum_activations(self):
        """Per-layer product of output dimensions 1..3."""
        return self._collect('maximum_activations')

    @property
    def clipped_avg_gradients(self):
        """Per-layer lists of scaled and clipped averaged gradients."""
        return self._collect('clipped_grads')

    @property
    def ssd_weight(self):
        """Per-layer lists of mean absolute averaged gradients for 'w' variables."""
        return self._collect('ssd_weight')

    @property
    def reset_gradient(self):
        """Per-layer lists of ops that zero the gradient accumulators."""
        return self._collect('reset_op')

    @property
    def update_gradient(self):
        """Per-layer lists of ops that add the current gradient to the accumulators."""
        return self._collect('update_op')

    @property
    def apply_gradient(self):
        """Per-layer ops that apply the clipped gradients to the variables."""
        return self._collect('apply_op')
