import sys
import tensorflow as tf
import numpy as np
import csv
import io
from netCDF4 import Dataset
from subprocess import call


import cv2
from subprocess import call
import struct
from skimage.measure import compare_ssim as ssim
import scipy.misc as misc

def get_grid(x):
    """Build (batch, y, x) index grids matching the leading dims of ``x``.

    ``x`` must be a rank-4 tensor: its dynamic shape is unpacked into four
    values and the last one is ignored. Three tensors are returned, each
    shaped ``(batch, height, width)`` ('ij' indexing), holding respectively
    the batch index, the row index and the column index of every position.
    They are returned separately rather than stacked so callers can process
    them element-wise.
    """
    n, h, w, _ = tf.unstack(tf.shape(x))
    axes = (tf.range(n), tf.range(h), tf.range(w))
    batch_idx, row_idx, col_idx = tf.meshgrid(*axes, indexing='ij')
    return batch_idx, row_idx, col_idx

def show_progress(epoch, batch, batch_total, loss, epe):
    """Overwrite the current console line with training progress.

    Writes a carriage return followed by
    ``"<epoch> epoch: [<batch>/<batch_total>, loss: <loss>, epe: <epe>]"``
    to stdout (no trailing newline) and flushes so the update shows at once.
    """
    stream = sys.stdout
    status = f'\r{epoch} epoch: [{batch}/{batch_total}, loss: {loss}, epe: {epe}]'
    stream.write(status)
    stream.flush()

def get_nmq_data_oneframe(args,frame_id=-1):
    """Load one down-sampled NMQ precipitation-rate frame.

    ``../200912/nmq.txt`` lists one NetCDF file name per line; a name ending
    in ``.gz`` is decompressed in place with ``gzip -df`` before opening.
    Only every ``frame_interval``-th line counts as a sample, so sampled
    frame ``frame_id`` is the file on line ``frame_id * frame_interval``
    (0-based). Its ``PRECIPRATE_HSR`` variable is sub-sampled by ``scale``
    along both axes and values below 0.01 are set to 0. The number of
    positive cells and the mean / standard deviation of the non-negative
    cells are printed.

    Args:
        args: Unused; kept so all ``get_*_data_oneframe`` loaders share a
            signature.
        frame_id: Index of the sampled frame to load; must be >= 0.

    Returns:
        The sub-sampled, thresholded 2-D precipitation field.

    Raises:
        IndexError: If ``frame_id`` is negative or past the end of the list.
    """
    import os

    scale = 20
    frame_interval = 2
    data_dir = "../200912"

    if frame_id < 0:
        # A negative id (including the default -1) can never select a line.
        raise IndexError(f"frame_id must be >= 0, got {frame_id}")

    target_line = frame_id * frame_interval
    list_path = data_dir + '/nmq.txt'
    with open(list_path, 'r') as filenames:
        for line_no, line in enumerate(filenames):
            if line_no == target_line:
                break
        else:
            raise IndexError(f"frame_id {frame_id} is past the end of {list_path}")

    # rstrip rather than [:-1]: the last line may have no trailing newline.
    nc_path = data_dir + '/PRECIPRATE_HSR/' + line.rstrip('\r\n')
    if nc_path.endswith('.gz'):
        # gzip -d replaces foo.gz with foo, so on a repeat call only the
        # decompressed file is left; decompress only if the archive exists.
        if os.path.exists(nc_path):
            call(['gzip', '-df', nc_path])
        nc_path = nc_path[:-3]

    setAll = Dataset(nc_path, 'r')
    try:
        oneframe = setAll.variables['PRECIPRATE_HSR'][:, :]
        oneframe_s = oneframe[::scale, ::scale]
        oneframe_s = np.where(oneframe_s >= 0.01, oneframe_s, 0)
        print(len(np.where(oneframe_s > 0)[0]))

        non_negative = oneframe_s[np.where(oneframe_s >= 0)]
        mean = np.mean(non_negative)
        # Population standard deviation (printed under the label "mu").
        mu = np.sqrt(np.mean(np.square(non_negative - mean)))
        print("mean: " + str(mean))
        print("mu: " + str(mu))
    finally:
        setAll.close()
    return oneframe_s



def get_vol_data_oneframe(args,frame_id=-1,data_dir="../volume_rendering"):
    """Load one 128x128 slice of the vorticity data set.

    ``<data_dir>/vorts_slice.txt`` lists one raw binary file name per line.
    The file on line ``frame_id`` (0-based) is read from
    ``<data_dir>/vorts_slice/`` as 128*128 native-endian float32 values in
    row-major order. The mean / standard deviation of the non-negative
    values and the maximum are printed.

    Args:
        args: Unused; kept so all ``get_*_data_oneframe`` loaders share a
            signature.
        frame_id: Index of the frame to load; must be >= 0.
        data_dir: Root directory holding the list file and slice folder.

    Returns:
        A float64 array of shape (128, 128).

    Raises:
        IndexError: If ``frame_id`` is negative or past the end of the list.
        ValueError: If the slice file holds fewer than 128*128 floats.
    """
    scale = 1
    data_width = 128
    data_height = 128
    frame_interval = 1

    if frame_id < 0:
        raise IndexError(f"frame_id must be >= 0, got {frame_id}")

    target_line = frame_id * frame_interval
    list_path = data_dir + '/vorts_slice.txt'
    with open(list_path, 'r') as filenames:
        for line_no, line in enumerate(filenames):
            if line_no == target_line:
                break
        else:
            raise IndexError(f"frame_id {frame_id} is past the end of {list_path}")

    filename = data_dir + '/vorts_slice/' + line.rstrip('\r\n')
    n_bytes = 4 * data_height * data_width
    with open(filename, 'rb') as framefile:
        raw = framefile.read(n_bytes)
    if len(raw) < n_bytes:
        raise ValueError(f"{filename}: expected {n_bytes} bytes, got {len(raw)}")
    # Native-endian float32, row-major (the layout struct.unpack('f') reads),
    # widened to float64; astype also makes the read-only buffer writable.
    oneframe = np.frombuffer(raw, dtype=np.float32).reshape(
        data_height, data_width).astype(np.float64)

    oneframe_s = oneframe[::scale, ::scale]
    non_negative = oneframe_s[np.where(oneframe_s >= 0)]
    mean = np.mean(non_negative)
    # Population standard deviation (printed under the label "mu").
    mu = np.sqrt(np.mean(np.square(non_negative - mean)))
    maximum = np.amax(oneframe_s)

    print("mean: " + str(mean))
    print("mu: " + str(mu))
    print("max: " + str(maximum))
    return oneframe_s


def get_isabel_data_oneframe(args,frame_id=-1,data_dir="../volume_rendering"):
    """Load one 500x500 slice of the Isabel data set.

    ``<data_dir>/isabel_slice.txt`` lists one raw binary file name per line.
    The file on line ``frame_id`` (0-based) is read from
    ``<data_dir>/isabel_slice/`` as 500*500 native-endian float32 values in
    row-major order. Values below 0.001 are set to 0. The mean / standard
    deviation of the non-negative values and the maximum are printed.

    Args:
        args: Unused by the active code; kept so all
            ``get_*_data_oneframe`` loaders share a signature.
        frame_id: Index of the frame to load; must be >= 0.
        data_dir: Root directory holding the list file and slice folder.

    Returns:
        A float64 array of shape (500, 500).

    Raises:
        IndexError: If ``frame_id`` is negative or past the end of the list.
        ValueError: If the slice file holds fewer than 500*500 floats.
    """
    scale = 1
    data_width = 500
    data_height = 500
    frame_interval = 1

    if frame_id < 0:
        raise IndexError(f"frame_id must be >= 0, got {frame_id}")

    target_line = frame_id * frame_interval
    list_path = data_dir + '/isabel_slice.txt'
    with open(list_path, 'r') as filenames:
        for line_no, line in enumerate(filenames):
            if line_no == target_line:
                break
        else:
            raise IndexError(f"frame_id {frame_id} is past the end of {list_path}")

    filename = data_dir + '/isabel_slice/' + line.rstrip('\r\n')
    n_bytes = 4 * data_height * data_width
    with open(filename, 'rb') as framefile:
        raw = framefile.read(n_bytes)
    if len(raw) < n_bytes:
        raise ValueError(f"{filename}: expected {n_bytes} bytes, got {len(raw)}")
    # Native-endian float32, row-major (the layout struct.unpack('f') reads),
    # widened to float64.
    oneframe = np.frombuffer(raw, dtype=np.float32).reshape(
        data_height, data_width).astype(np.float64)

    oneframe_s = oneframe[::scale, ::scale]
    oneframe_s = np.where(oneframe_s >= 0.001, oneframe_s, 0)
    non_negative = oneframe_s[np.where(oneframe_s >= 0)]
    mean = np.mean(non_negative)
    # Population standard deviation (printed under the label "mu").
    mu = np.sqrt(np.mean(np.square(non_negative - mean)))
    maximum = np.amax(oneframe_s)
    print("mean: " + str(mean))
    print("mu: " + str(mu))
    print("max: " + str(maximum))
    return oneframe_s








def map_to_color(frames,l,r,cmap=None):
    """Map scalar frames to RGB via a colormap indexed 0..255.

    Values are rescaled linearly so that ``l`` maps to index 80 and ``r`` to
    index 255. Any index <= 80 becomes 0 and any index >= 255 becomes 255.
    The colormap is evaluated on the truncated integer indices, multiplied by
    255, and only the first three channels are kept.

    Args:
        frames: Numeric array of scalar values.
        l: Scalar value mapped to index 80; must differ from ``r``.
        r: Scalar value mapped to index 255.
        cmap: Optional callable taking an integer index array and returning
            an array whose last axis holds colour channels in [0, 1]. When
            omitted, ``colormap.Colormap().cmap_linear('black', 'white',
            'red')`` is used.

    Returns:
        Array of shape ``frames.shape + (3,)`` with values in [0, 255].

    Raises:
        ValueError: If ``l == r`` (the rescale would divide by zero).
    """
    if r == l:
        raise ValueError(f"l and r must differ, got l == r == {l}")
    if cmap is None:
        # NOTE(review): ``colormap`` is not imported in the visible part of
        # this file; the default path needs it in scope (e.g. `import colormap`).
        c = colormap.Colormap()
        cmap = c.cmap_linear('black', 'white', 'red')
    frames = (frames - l) * (255 - 80) / (r - l) + 80
    frames = np.where(frames <= 80, 0, frames)
    frames = np.where(frames >= 255, 255, frames)
    # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
    frames = cmap(frames.astype(int)) * 255
    return frames[..., 0:3]

def map_to_scale(frames,l,r,s):
    """Replicate a frame into three channels, clip it, and scale it.

    The input is stacked three times along axis 2 (for a 2-D frame this
    yields shape ``(H, W, 3)``). Values ``<= l`` become 0, values ``>= r``
    become ``r``, and the result is multiplied by ``s``.
    """
    rgb = np.stack((frames, frames, frames), axis=2)
    below = rgb <= l
    rgb = np.where(below, 0, rgb)
    above = rgb >= r
    rgb = np.where(above, r, rgb)
    return rgb * s

def compute_similarity(a,b):
    """Compare two equal-length sequences of images pair by pair.

    For each pair ``(a[k], b[k])`` this computes:
      * RMSE: square root of the mean squared difference. It is returned in
        the list historically named ``mses``.
      * SSIM via ``ssim`` with ``data_range = a[k].max() - a[k].min()`` and
        ``multichannel=True``.
      * PSNR: ``20 * log10(a[k].max() / RMSE)``.

    The difference is taken in float64 so unsigned-integer images (e.g.
    uint8) do not wrap around on subtraction.

    Args:
        a: Sequence of reference images (NumPy arrays).
        b: Sequence of images to compare, same length as ``a``.

    Returns:
        Tuple ``(rmses, ssims, psnrs)`` of lists, one entry per pair.

    Raises:
        ValueError: If ``a`` and ``b`` differ in length.
    """
    # NOTE(review): `compare_ssim` was removed from skimage.measure in newer
    # scikit-image releases (now skimage.metrics.structural_similarity).
    if len(a) != len(b):
        raise ValueError(f"length mismatch: len(a)={len(a)}, len(b)={len(b)}")
    mses = []
    ss = []
    psnrs = []
    for ref, test in zip(a, b):
        diff = np.asarray(ref, dtype=np.float64) - np.asarray(test, dtype=np.float64)
        rmse = np.sqrt(np.mean(np.square(diff)))
        s = ssim(ref, test, data_range=ref.max() - ref.min(), multichannel=True)
        psnr = 20 * np.log10(ref.max() / rmse)
        mses.append(rmse)
        ss.append(s)
        psnrs.append(psnr)
    return mses, ss, psnrs


def compute_run_length_encoding_len(frame):
    """Return how many runs ``rlencode`` finds in ``frame``.

    The frame is flattened (a row-major copy) and run-length encoded; only
    the number of runs is returned.
    """
    _, run_lengths, _ = rlencode(frame.flatten())
    return len(run_lengths)


def rlencode(x, dropna=False, atol=1e-02, rtol=1e-05):
    """
    Run length encoding with a closeness tolerance.
    Based on http://stackoverflow.com/a/32681075, which is based on the rle 
    function from R.

    Unlike exact RLE, a new run starts only where a value is NOT close to
    its immediate predecessor (``np.isclose`` with ``atol``/``rtol``; NaNs
    compare equal to each other). Because only neighbours are compared, a
    slowly drifting sequence can stay in one run even when its total change
    exceeds ``atol``. Each run's value is its first element.
    
    Parameters
    ----------
    x : 1D array_like
        Input array to encode
    dropna: bool, optional
        Drop all runs of NaNs.
    atol : float, optional
        Absolute tolerance for merging neighbouring values (default 1e-2).
    rtol : float, optional
        Relative tolerance, relative to the preceding value (default 1e-5).
    
    Returns
    -------
    start positions, run lengths, run values
    
    """
    where = np.flatnonzero
    x = np.asarray(x)
    n = len(x)
    if n == 0:
        return (np.array([], dtype=int), 
                np.array([], dtype=int), 
                np.array([], dtype=x.dtype))

    # Index k+1 starts a new run when x[k+1] is not close to x[k].
    is_break = ~np.isclose(x[1:], x[:-1], rtol=rtol, atol=atol, equal_nan=True)
    starts = np.r_[0, where(is_break) + 1]
    lengths = np.diff(np.r_[starts, n])
    values = x[starts]
    
    if dropna:
        mask = ~np.isnan(values)
        starts, lengths, values = starts[mask], lengths[mask], values[mask]
    
    return starts, lengths, values