import os
import re
import argparse
import numpy as np
import tensorflow as tf
import imageio
import scipy.misc
from utils import get_nmq_data_oneframe
from utils import *
from utils import map_to_color
from utils import map_to_scale
from utils import compute_similarity
from utils import compute_run_length_encoding_len
from flow_utils import vis_flow_pyramid
from flow_utils import vis_interpolated_imgs
from flow_utils import vis_flow
from model import PWCNet
import math
import cv2
import matplotlib.pyplot as plt




class Tester(object):
    """Interpolate the frames between two data frames with a learned flow field.

    The constructor loads every frame of the requested sequence, builds a
    TensorFlow 1.x graph that warps the first and last frame towards each
    intermediate time step, and creates a saver.  `test()` then either
    optimises the flow variables ("particle_train"), restores them
    ("particle_test"), or feeds an externally computed flow ("GF", "PWC",
    "FIFO"), and reports reconstruction and histogram metrics.
    """

    # Values of `args.method` handled by `test()`.
    _METHODS = ("particle_train", "particle_test", "GF", "PWC", "FIFO")

    # Fixed histogram binning per data type as (start, stop, step) for
    # np.arange; `step` is also the bin width used to turn densities into
    # per-bin probability mass.
    _HIST_BINS = {
        "vortex": (0, 7.5, 0.5),
        "isabel": (0, 0.017, 0.001),
    }
    # Number of equal-width bins used for data types without a fixed binning.
    _FALLBACK_BIN_COUNT = 16

    _CHECKPOINT = "test_figure/model/particle_flow.ckpt"

    def __init__(self, args):
        """Create the session, build the graph from `args`, and set up a saver.

        Args:
            args: namespace providing input_images, data_type, lower_threshold,
                upper_threshold, data_scale, flow_scale, temporal_interval,
                lr, n_epoch and method.
        """
        self.args = args
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        self._build_graph()
        self.saver = tf.train.Saver()

    def _load_frame(self, frame_idx):
        """Load one raw frame with the loader matching `args.data_type`.

        Raises:
            ValueError: if `args.data_type` is not one of the supported types.
        """
        data_type = self.args.data_type
        if data_type == "vortex":
            return get_vol_data_oneframe(self.args, frame_idx)
        if data_type == "isabel":
            return get_isabel_data_oneframe(self.args, frame_idx)
        if data_type == "nmq":
            return get_nmq_data_oneframe(self.args, frame_idx)
        raise ValueError(
            "Unsupported data_type %r (expected 'vortex', 'isabel' or 'nmq')"
            % (data_type,))

    @staticmethod
    def _bilinear_sample(images, pos_h, pos_w, height, width):
        """Bilinearly sample each image in `images` at (pos_h, pos_w).

        `pos_h` / `pos_w` are per-pixel sampling coordinates that the caller
        has already clipped to [0, height-1] / [0, width-1].  The
        interpolation weights are replicated three times along axis 2, i.e.
        the images are assumed to have three channels.

        Returns:
            A list with one sampled tensor per entry of `images`.
        """
        lower_h = tf.floor(pos_h)
        upper_h = tf.minimum(lower_h + 1, height - 1)
        lower_w = tf.floor(pos_w)
        upper_w = tf.minimum(lower_w + 1, width - 1)

        # Fractional offsets inside the cell, one copy per channel.
        u = tf.cast(tf.stack([pos_h - lower_h] * 3, axis=2), dtype=tf.float32)
        v = tf.cast(tf.stack([pos_w - lower_w] * 3, axis=2), dtype=tf.float32)

        def corner(h, w):
            return tf.cast(tf.stack([h, w], axis=2), dtype=tf.int32)

        idx_ll = corner(lower_h, lower_w)
        idx_ul = corner(upper_h, lower_w)
        idx_lu = corner(lower_h, upper_w)
        idx_uu = corner(upper_h, upper_w)

        return [
            (1 - v) * ((1 - u) * tf.gather_nd(image, idx_ll)
                       + u * tf.gather_nd(image, idx_ul))
            + v * ((1 - u) * tf.gather_nd(image, idx_lu)
                   + u * tf.gather_nd(image, idx_uu))
            for image in images
        ]

    def _build_graph(self):
        """Load the frame sequence and build the interpolation graph.

        Side effects: writes 'test_figure/start.png' and 'test_figure/end.png'
        (the "GF" method reads them back) and prints progress messages.

        Raises:
            ValueError: if the requested sequence has fewer than two frames or
                `args.data_type` is unsupported.
        """
        first_path, last_path = self.args.input_images
        first_idx = int(first_path)
        self.seq_length = int(last_path) - first_idx + 1
        if self.seq_length < 2:
            # t = m / (seq_length - 1) below needs at least two frames.
            raise ValueError(
                "input_images must span at least two frames, got %r"
                % (self.args.input_images,))

        self.img = []
        self.img_raw = []
        self.img_ms = []
        self.img_raw_ms = []
        for offset in range(self.seq_length):
            frame_idx = first_idx + offset
            raw_frame = self._load_frame(frame_idx)
            # Use self.args (not a module-level `args`) so the class also
            # works when it is imported rather than run as a script.
            self.img.append(map_to_scale(raw_frame,
                                         self.args.lower_threshold,
                                         self.args.upper_threshold,
                                         s=self.args.data_scale))
            # Raw frame replicated into three channels along axis 2.
            self.img_raw.append(np.stack([raw_frame] * 3, axis=2))
            print('img loaded:' + str(frame_idx))

        # The graph is built from __init__, i.e. before test() runs, so the
        # output directory has to be created here.
        os.makedirs('test_figure', exist_ok=True)
        # NOTE(review): scipy.misc.imsave only exists in old SciPy releases;
        # kept because its output scaling may differ from imageio.imwrite.
        scipy.misc.imsave('test_figure/start.png', self.img[0])
        scipy.misc.imsave('test_figure/end.png', self.img[-1])

        self.start_end = np.array([self.img[0], self.img[-1]])
        self.end_start = np.array([self.img[-1], self.img[0]])
        self.max_value = np.amax(self.start_end)

        height, width = self.img[0].shape[0], self.img[0].shape[1]
        # indices_h[i, j] == i and indices_w[i, j] == j (float64 grids).
        self.indices_h, self.indices_w = np.meshgrid(
            np.arange(height, dtype=np.float64),
            np.arange(width, dtype=np.float64),
            indexing='ij')

        # Flow variables live on a grid coarsened by flow_scale and are
        # resized back to the full frame resolution.
        self.flow_scale = self.args.flow_scale
        coarse_shape = [height // self.flow_scale, width // self.flow_scale]

        def angle_variable(name):
            return tf.get_variable(
                name, coarse_shape,
                initializer=tf.truncated_normal_initializer(mean=0, stddev=.5))

        def radius_variable(name):
            return tf.get_variable(
                name, coarse_shape,
                initializer=tf.truncated_normal_initializer(mean=0, stddev=.05),
                regularizer=tf.contrib.layers.l2_regularizer(scale=0.1))

        def upsample(field):
            return tf.squeeze(tf.image.resize_images(
                tf.expand_dims(field, axis=2), [height, width]))

        self.flow_theta = upsample(angle_variable("flow_theta"))
        self.flow_r = upsample(radius_variable("flow_r"))
        self.flow_theta_re = upsample(angle_variable("flow_theta_re"))
        self.flow_r_re = upsample(radius_variable("flow_r_re"))

        # 3x3 kernel: centre weight 1, all eight neighbours -1/8.
        kernel = np.full((3, 3, 1, 1), -1 / 8, dtype=np.float32)
        kernel[1, 1, 0, 0] = 1
        d_kernel = tf.constant(kernel)

        self.if_our_method = tf.placeholder(tf.bool)
        self.flow_input = tf.placeholder(tf.float32)
        self.flow_re_input = tf.placeholder(tf.float32)

        def our_method():
            # Polar parametrisation: theta is in turns, r is relative to the
            # frame size.
            flow_h = tf.sin(self.flow_theta * math.pi * 2) * self.flow_r * height
            flow_w = tf.cos(self.flow_theta * math.pi * 2) * self.flow_r * width
            flow_h_re = tf.sin(self.flow_theta_re * math.pi * 2) * self.flow_r_re * height
            flow_w_re = tf.cos(self.flow_theta_re * math.pi * 2) * self.flow_r_re * width
            return flow_h, flow_w, flow_h_re, flow_w_re

        def other_methods():
            # NOTE(review): channel 0 is used as the row (h) displacement and
            # channel 1 as the column (w) displacement - confirm this matches
            # the channel order of the flows fed in by test().
            flow_h = self.flow_input[:, :, 0]
            flow_w = self.flow_input[:, :, 1]
            flow_h_re = self.flow_re_input[:, :, 0]
            flow_w_re = self.flow_re_input[:, :, 1]
            return flow_h, flow_w, flow_h_re, flow_w_re

        self.flow_h, self.flow_w, self.flow_h_re, self.flow_w_re = tf.cond(
            self.if_our_method, our_method, other_methods)

        img0 = tf.cast(self.img[0], dtype=tf.float32)
        img1 = tf.cast(self.img[-1], dtype=tf.float32)
        img0_raw = tf.cast(self.img_raw[0], dtype=tf.float32)
        img1_raw = tf.cast(self.img_raw[-1], dtype=tf.float32)

        def roughness(component):
            # Mean absolute response of d_kernel over the flow component,
            # divided by 8.
            response = tf.abs(tf.nn.conv2d(
                [tf.expand_dims(component, axis=2)], d_kernel,
                strides=[1, 1, 1, 1], padding='SAME'))
            return tf.reduce_mean(tf.squeeze(response) / 8)

        self.flow_h_d = roughness(self.flow_h)
        self.flow_w_d = roughness(self.flow_w)
        self.flow_h_d_re = roughness(self.flow_h_re)
        self.flow_w_d_re = roughness(self.flow_w_re)

        self.flow = tf.stack([self.flow_h, self.flow_w], axis=2)
        self.flow_re = tf.stack([self.flow_h_re, self.flow_w_re], axis=2)

        def to_position(disp_h, disp_w):
            # Pixel grid plus displacement, clipped to the frame.
            pos_h = tf.clip_by_value(self.indices_h + disp_h, 0, height - 1)
            pos_w = tf.clip_by_value(self.indices_w + disp_w, 0, width - 1)
            return pos_h, pos_w

        self.start_end_ms = []
        self.direction_similarities = []
        self.img_ms.append(img0)
        self.img_raw_ms.append(img0_raw)
        last = self.seq_length - 1
        for m in range(self.seq_length):
            t = m / last

            if m == 0:
                # Reconstruct the start frame by sampling the end frame along
                # the forward flow.
                pos_h, pos_w = to_position(self.flow_h, self.flow_w)
                warped, = self._bilinear_sample([img1], pos_h, pos_w, height, width)
                self.start_end_ms.append(warped)
                continue
            if m == last:
                # Reconstruct the end frame by sampling the start frame along
                # the reverse flow.
                pos_h, pos_w = to_position(self.flow_h_re, self.flow_w_re)
                warped, = self._bilinear_sample([img0], pos_h, pos_w, height, width)
                self.start_end_ms.append(warped)
                continue

            # Per-pixel blend of the two flows, weighted by the squared
            # magnitude of the time-scaled reverse flow relative to the sum.
            flow0_m = tf.square(self.flow_h * t) + tf.square(self.flow_w * t)
            flow1_m = (tf.square(self.flow_h_re * (1 - t))
                       + tf.square(self.flow_w_re * (1 - t)))
            ratio = flow1_m / (flow0_m + flow1_m + 1e-9)

            disp0_h = -ratio * t * self.flow_h + (1 - ratio) * t * self.flow_h_re
            disp0_w = -ratio * t * self.flow_w + (1 - ratio) * t * self.flow_w_re
            disp1_h = ratio * (1 - t) * self.flow_h - (1 - ratio) * (1 - t) * self.flow_h_re
            disp1_w = ratio * (1 - t) * self.flow_w - (1 - ratio) * (1 - t) * self.flow_w_re

            pos0_h, pos0_w = to_position(disp0_h, disp0_w)
            pos1_h, pos1_w = to_position(disp1_h, disp1_w)

            # NOTE(review): this cosine similarity is taken between the
            # clipped sampling *positions*, not the displacements - confirm
            # that is intended.
            direction_similarity = pos0_h * pos1_h + pos0_w * pos1_w
            direction_similarity = (
                direction_similarity
                / (tf.sqrt(tf.square(pos0_h) + tf.square(pos0_w)) + 1e-9)
                / (tf.sqrt(tf.square(pos1_h) + tf.square(pos1_w)) + 1e-9))

            img0_m, img0_raw_m = self._bilinear_sample(
                [img0, img0_raw], pos0_h, pos0_w, height, width)
            img1_m, img1_raw_m = self._bilinear_sample(
                [img1, img1_raw], pos1_h, pos1_w, height, width)

            # Cross-fade the two warped frames according to time.
            self.img_ms.append((1 - t) * img0_m + t * img1_m)
            self.img_raw_ms.append((1 - t) * img0_raw_m + t * img1_raw_m)
            self.direction_similarities.append(direction_similarity)

        self.direction_similarities = tf.reduce_mean(self.direction_similarities)
        self.img_ms.append(img1)
        self.img_raw_ms.append(img1_raw)

        original_imgs = tf.convert_to_tensor(np.array(self.img), dtype=tf.float32)
        self.err = tf.abs(tf.stack(self.img_ms, axis=0) - original_imgs)
        self.mean_err_start_end = tf.abs(
            tf.stack(self.start_end_ms, axis=0)
            - tf.convert_to_tensor(self.start_end, dtype=tf.float32)) / self.max_value

        # Only every temporal_interval-th frame contributes to mean_err.
        self.temporal_interval = self.args.temporal_interval
        self.errs = [self.err[i, :, :]
                     for i in range(0, self.seq_length, self.temporal_interval)]
        self.mean_err = tf.reduce_mean(self.errs)

        self.alpha = tf.abs(tf.get_variable(
            "alpha", shape=[],
            initializer=tf.truncated_normal_initializer(mean=0, stddev=.1)))
        self.beta = tf.abs(tf.get_variable(
            "beta", shape=[],
            initializer=tf.truncated_normal_initializer(mean=0, stddev=.1)))

        # NOTE(review): neither loss below contains the reconstruction error
        # (self.mean_err); earlier variants combined mean_err, the roughness
        # terms, mean_err_start_end and direction_similarities.
        if self.args.data_type in ("nmq", "vortex"):
            self.loss = (tf.abs(tf.reduce_mean(self.flow_h))
                         + tf.abs(tf.reduce_mean(self.flow_w))
                         + tf.abs(tf.reduce_mean(self.flow_h_re))
                         + tf.abs(tf.reduce_mean(self.flow_w_re)))
        else:  # "isabel" - other values were rejected by _load_frame
            self.loss = (self.flow_h_d + self.flow_h_d_re
                         + self.flow_w_d + self.flow_w_d_re) / 4 * 8

        print('Graph built')
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=self.args.lr).minimize(self.loss)

    def _evaluate(self, feed):
        """Run the graph once and return flows and interpolated frames.

        Returns:
            (flow, flow_re, img_ms, img_raw_ms) as NumPy results.
        """
        return self.sess.run(
            [self.flow, self.flow_re, self.img_ms, self.img_raw_ms],
            feed_dict=feed)

    @staticmethod
    def _report_coding_lengths(flow_result, flow_re_result, img_raw_ms_results):
        """Print run-length-encoding ratios and plot one row of the first frame."""
        print(flow_result.shape)
        print(img_raw_ms_results[0].shape)
        height, width, channel = img_raw_ms_results[0].shape
        n_frames = len(img_raw_ms_results)

        data_start_end_coding_len = (
            compute_run_length_encoding_len(img_raw_ms_results[0])
            + compute_run_length_encoding_len(img_raw_ms_results[-1]))
        flow_coding_len = (compute_run_length_encoding_len(flow_result)
                           + compute_run_length_encoding_len(flow_re_result))
        data_coding_len = 0
        for raw_image in img_raw_ms_results:
            data_coding_len += compute_run_length_encoding_len(raw_image)

        run_length_encoding_ratio = data_coding_len / n_frames / width / height / channel
        flow_encoding_ratio = ((flow_coding_len + data_start_end_coding_len)
                               / n_frames / width / height / channel)
        print(run_length_encoding_ratio)
        print(flow_encoding_ratio)

        # Plot row 20 (or the last row of shorter frames); the x range must
        # match the row length, i.e. the frame width.
        row = min(20, height - 1)
        plt.bar(range(width), img_raw_ms_results[0][row, :, 0])
        plt.show()

    @classmethod
    def _histogram_bins(cls, data_type, samples):
        """Return (bin_edges, bin_width) for the value histogram.

        "vortex" and "isabel" use their fixed binning.  Any other data type
        (e.g. "nmq") gets `_FALLBACK_BIN_COUNT` equal-width bins spanning the
        range of `samples`.
        """
        preset = cls._HIST_BINS.get(data_type)
        if preset is not None:
            start, stop, step = preset
            return np.arange(start, stop, step), step

        values = np.asarray(samples, dtype=float)
        lo, hi = float(values.min()), float(values.max())
        if not hi > lo:
            # Constant data: widen the range so the bins have non-zero width.
            hi = lo + 1.0
        edges = np.linspace(lo, hi, cls._FALLBACK_BIN_COUNT + 1)
        return edges, (hi - lo) / cls._FALLBACK_BIN_COUNT

    @staticmethod
    def _distribution_distances(hist_interpolated, hist_original):
        """Return (KS, KL, BD) between two per-bin probability vectors.

        NOTE(review): "KS" here is the maximum signed difference of the
        per-bin masses, not the supremum distance between CDFs.
        """
        p = np.asarray(hist_interpolated)
        q = np.asarray(hist_original)
        ks = np.amax(p - q)
        kl = np.sum(np.log(p / (q + 1e-5) + 1e-5) * p)
        bd = -np.log(np.sum(np.sqrt(p * q)))
        return ks, kl, bd

    def test(self):
        """Run the method selected by `args.method` and report the metrics.

        Prints similarity and histogram-distance metrics, shows histogram
        plots, and writes channel 0 of every interpolated raw frame to
        'test_figure/<data_type>_<frame index>.dat'.

        Raises:
            ValueError: if `args.method` is not a supported method.
        """
        method = self.args.method
        if method not in self._METHODS:
            raise ValueError("Unsupported method %r (expected one of %s)"
                             % (method, ", ".join(self._METHODS)))

        os.makedirs('./test_figure', exist_ok=True)

        # The placeholders are graph inputs of tf.cond, so they are fed even
        # when the learned flow is selected.
        learned_feed = {self.if_our_method: True,
                        self.flow_input: 0,
                        self.flow_re_input: 0}

        if method == "particle_train":
            print("Learning")
            self.sess.run(tf.global_variables_initializer())
            for i in range(self.args.n_epoch):
                _, loss_value = self.sess.run([self.optimizer, self.loss],
                                              feed_dict=learned_feed)
                if i % 100 == 0:
                    # The printed value is the optimised loss.
                    print("Mean err of " + str(i) + " :" + str(loss_value))
            # The checkpoint directory must exist before saving.
            os.makedirs(os.path.dirname(self._CHECKPOINT), exist_ok=True)
            self.saver.save(self.sess, self._CHECKPOINT)
            _, _, img_ms_results, img_raw_ms_results = self._evaluate(learned_feed)

        elif method == "particle_test":
            print("Loading existing model")
            self.saver.restore(self.sess, self._CHECKPOINT)
            flow_result, flow_re_result, img_ms_results, img_raw_ms_results = \
                self._evaluate(learned_feed)
            self._report_coding_lengths(flow_result, flow_re_result,
                                        img_raw_ms_results)

        else:
            if method == "GF":
                print("Gunner Farneback")
                img1, img2 = map(imageio.imread,
                                 ('test_figure/start.png', 'test_figure/end.png'))
                gray1 = cv2.cvtColor(img1[:, :, 0:3], cv2.COLOR_RGB2GRAY)
                gray2 = cv2.cvtColor(img2[:, :, 0:3], cv2.COLOR_RGB2GRAY)
                flow = cv2.calcOpticalFlowFarneback(
                    gray1, gray2, None, 0.5, 3, 15, 3, 5, 1.2, 0)
                flow_re = cv2.calcOpticalFlowFarneback(
                    gray2, gray1, None, 0.5, 3, 15, 3, 5, 1.2, 0)
            elif method == "PWC":
                print("PWCNet")
                flow = np.load("test_figure/flow.npy")
                flow_re = np.load("test_figure/flow_re.npy")
            else:  # "FIFO": zero flow, i.e. a plain cross-fade
                print("Fade-In-Fade-Out")
                flow_shape = (self.img[0].shape[0], self.img[0].shape[1], 2)
                flow = np.zeros(shape=flow_shape)
                flow_re = np.zeros(shape=flow_shape)

            self.sess.run(tf.global_variables_initializer())
            external_feed = {self.if_our_method: False,
                             self.flow_input: flow,
                             self.flow_re_input: flow_re}
            _, _, img_ms_results, img_raw_ms_results = self._evaluate(external_feed)

        # Averages exclude the first and last frame, which are copied from
        # the input unchanged.
        mses, ssims, psnrs = compute_similarity(self.img_raw, img_raw_ms_results)
        mse_avg = np.mean(np.array(mses[1:-1]))
        ssim_avg = np.mean(np.array(ssims[1:-1]))
        psnr_avg = np.mean(np.array(psnrs[1:-1]))

        print("mse: " + str(mses))
        print("ssim: " + str(ssims))
        print("psnr: " + str(psnrs))
        print("mse_avg: " + str(mse_avg))
        print("ssim_avg: " + str(ssim_avg))
        print("psnr_avg: " + str(psnr_avg))
        print(img_ms_results[0].shape)

        interpolated_frames = np.array(img_raw_ms_results)
        interpolated_values = interpolated_frames[:, :, :, 0].flatten()
        original_values = np.array(self.img_raw)[:, :, :, 0].flatten()

        bins, bin_width = self._histogram_bins(
            self.args.data_type,
            np.concatenate([interpolated_values, original_values]))

        # density * bin width = probability mass per bin.
        hist_interpolated, _ = np.histogram(interpolated_values, bins=bins, density=True)
        hist_original, _ = np.histogram(original_values, bins=bins, density=True)
        hist_interpolated = hist_interpolated * bin_width
        hist_original = hist_original * bin_width

        KS, KL, BD = self._distribution_distances(hist_interpolated, hist_original)

        print("Kolmogorov–Smirnov distance: " + str(KS))
        print("Kullback–Leibler divergence: " + str(KL))
        print("Bhattacharyya distance: " + str(BD))
        print(str(np.sum(hist_interpolated)))

        plt.title("Histogram of all the data frames")
        plt.subplot(211)
        plt.hist(interpolated_values, bins=bins, density=True)

        plt.subplot(212)
        plt.hist(original_values, bins=bins, density=True)

        plt.show()

        # Dump channel 0 of each interpolated raw frame as raw bytes, named
        # after the absolute frame index.
        first_idx = int(self.args.input_images[0])
        for i in range(self.seq_length):
            out_path = ("test_figure/" + self.args.data_type + "_"
                        + "{:03}".format(i + first_idx) + ".dat")
            with open(out_path, 'wb') as f:
                f.write(interpolated_frames[i, :, :, 0].tobytes())
               

def build_arg_parser():
    """Return the command-line parser for this script.

    `--input_images` takes the first and last frame index of the sequence
    (both are converted with int() by Tester).  `--data_type` and `--method`
    are restricted to the values Tester handles, so a typo is reported by
    argparse instead of failing later inside the graph or evaluation code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_images', type=str, nargs=2, required=True,
                        help='First and last frame index of the sequence (required)')
    parser.add_argument('--resume', type=str, default=None,
                        help='Learned parameter checkpoint file [None]')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate [1e-4]')
    parser.add_argument('--n_epoch', type=int, default=3000,
                        help='# of epochs [3000]')
    parser.add_argument('--lower_threshold', type=float, default=0.009,
                        help='lower threshold of data preprocessing')
    parser.add_argument('--upper_threshold', type=float, default=0.017,
                        help='upper threshold of data preprocessing')
    parser.add_argument('--data_scale', type=float, default=30,
                        help='scale of data preprocessing')
    parser.add_argument('--data_type', type=str, default="vortex",
                        choices=["vortex", "isabel", "nmq"],
                        help='type of data')
    parser.add_argument('--flow_scale', type=int, default=1,
                        help='resolution of particle flow')
    parser.add_argument('--method', type=str, default="particle_train",
                        choices=["particle_train", "particle_test", "GF", "PWC", "FIFO"],
                        help='type of method')
    parser.add_argument('--temporal_interval', type=int, default=1,
                        help='temporal interval for err calculation')
    return parser


if __name__ == '__main__':
    # `args` is deliberately bound at module scope: code in this module may
    # refer to it as a global.
    args = build_arg_parser().parse_args()
    for key, item in vars(args).items():
        print(f'{key} : {item}')

    #os.environ['CUDA_VISIBLE_DEVICES'] = input('Input utilize gpu-id (-1:cpu) : ')

    tester = Tester(args)
    tester.test()
