/*************************************************
Copyright:
Author: Lina Yu, Hongfeng Yu
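Description: Read a raw float volume (.dat) and compute occlusion
statistics for each voxel (a neighborhood occlusion sum and the spread
of occlusion over ray directions) to help find holes.
Output: <out_file_sum>.dat/.hdr  // per-voxel neighborhood occlusion sum
        <out_file_std>.dat/.hdr  // per-voxel directional spread ("std")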
**************************************************/

#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cassert>
#include <sys/times.h>
#include <cuda.h>
#include <float.h>
// #include "helper_cuda.h"


using namespace std;

// NOTE(review): this global is not referenced anywhere in the code shown here;
// main() launches the kernel with its own dim3 block size. Confirm whether it
// is still needed.
int theBlockSize = 64;

// Ceiling division: the number of b-sized blocks needed to cover a items.
// Used to size the launch grid in each dimension.
int iDivUp(int a, int b)
{
    const int whole = a / b;
    return (a % b == 0) ? whole : whole + 1;
}

// A pair of float values named x and y.
// NOTE(review): not referenced elsewhere in the code shown here.
typedef struct {
    float x, y;
} data_t;

// Minimal 3-component float vector with public components.
// Host-only: the methods carry no __device__ qualifier.
class vec3d{
public:
    // Euclidean norm: sqrt(x^2 + y^2 + z^2).
    float length() {
        const float squared = x*x + y*y + z*z;
        return sqrt(squared);
    }

    // Scale the vector to unit length in place.
    // A zero-length vector is left unchanged (avoids division by zero).
    void normalize() {
        const float norm = length();
        if (norm == 0) {
            return;
        }
        x /= norm;
        y /= norm;
        z /= norm;
    }

    float x;
    float y;
    float z;
};



// Linear index of voxel (ix, iy, iz) in a volume laid out with x varying
// fastest, then y, then z. Computed in 64-bit so volumes with more than
// INT_MAX voxels do not overflow.
__host__ __device__ inline size_t interp_voxel_index(int ix, int iy, int iz, int sx, int sy)
{
    return (size_t)ix + (size_t)iy * sx + (size_t)iz * sx * sy;
}

// Trilinearly interpolate the scalar volume `data` (sx * sy * sz floats,
// x varying fastest) at the continuous voxel coordinate (x, y, z).
//
// Returns false and leaves *result untouched if the coordinate lies outside
// [0, sx-1] x [0, sy-1] x [0, sz-1] or is NaN. Otherwise it writes the
// interpolated value to *result and returns true.
//
// Exact integer coordinates read the voxel directly. On the upper faces of
// the volume the "+1" neighbour is clamped to the last voxel.
//
// Declared __host__ __device__ so it can also serve as a CPU reference.
__host__ __device__ bool interpolate_coor(float x, float y, float z, const float *data, int sx, int sy, int sz, float *result)
{
    // Written as !(in range) rather than (out of range) so NaN is rejected
    // too; NaN compares false to everything and would otherwise reach the
    // undefined (int)NaN conversion below.
    if (!(x >= 0 && x <= (sx - 1)) ||
        !(y >= 0 && y <= (sy - 1)) ||
        !(z >= 0 && z <= (sz - 1))) {
        return false;
    }

    // Coordinates are non-negative here, so truncation equals floor.
    const int x_0 = (int)x;
    const int y_0 = (int)y;
    const int z_0 = (int)z;

    if (x_0 == x && y_0 == y && z_0 == z) {
        *result = data[interp_voxel_index(x_0, y_0, z_0, sx, sy)];
        return true;
    }

    // Upper cell corners, clamped so coordinates on the last plane stay in range.
    const int x_1 = (x_0 + 1 > sx - 1) ? sx - 1 : x_0 + 1;
    const int y_1 = (y_0 + 1 > sy - 1) ? sy - 1 : y_0 + 1;
    const int z_1 = (z_0 + 1 > sz - 1) ? sz - 1 : z_0 + 1;

    // Fractional position inside the cell, each in [0, 1).
    const float fx = x - x_0;
    const float fy = y - y_0;
    const float fz = z - z_0;

    // Interpolate along x on the four cell edges ...
    const float v_00 = data[interp_voxel_index(x_0, y_0, z_0, sx, sy)] * (1 - fx) +
                       data[interp_voxel_index(x_1, y_0, z_0, sx, sy)] * fx;
    const float v_10 = data[interp_voxel_index(x_0, y_1, z_0, sx, sy)] * (1 - fx) +
                       data[interp_voxel_index(x_1, y_1, z_0, sx, sy)] * fx;
    const float v_01 = data[interp_voxel_index(x_0, y_0, z_1, sx, sy)] * (1 - fx) +
                       data[interp_voxel_index(x_1, y_0, z_1, sx, sy)] * fx;
    const float v_11 = data[interp_voxel_index(x_0, y_1, z_1, sx, sy)] * (1 - fx) +
                       data[interp_voxel_index(x_1, y_1, z_1, sx, sy)] * fx;

    // ... then along y on the two z-faces, then along z.
    const float v_0 = v_00 * (1 - fy) + v_10 * fy;
    const float v_1 = v_01 * (1 - fy) + v_11 * fy;

    *result = v_0 * (1 - fz) + v_1 * fz;
    return true;
}



// Per-voxel occlusion statistics.
//
// Launch: a 3D grid covering (sx, sy, sz), one thread per voxel; threads
// outside the volume return immediately. No shared memory is used.
//
// For voxel (x, y, z), with k = kernel:
//   ou_occl_sum[id] = sum of in_data over the axis-aligned cube
//                     [x-k, x+k] x [y-k, y+k] x [z-k, z+k], clipped to the volume.
//   ou_occl_std[id] = sqrt( sum over directions d of (ray_d - avg)^2 ), where
//                     ray_d sums k samples of in_data along d at radii 0..k-1
//                     (samples outside the volume count as max_data), and
//                     avg = ou_occl_sum[id] / (360 * 180). Directions use polar
//                     angle theta = 0..179 deg (step 1) and azimuth
//                     psi = 0..358 deg (step 2).
//
// NOTE(review): avg divides by 360*180 = 64800 although 180*180 = 32400
// directions are sampled, and a cube sum is not the same quantity as a ray
// sum. The result is also not divided by the direction count, so it is not a
// true standard deviation. Preserved as-is; confirm the intended normalisation.
__global__ void occl_kernel_3d_mask(float *in_data,
                                    float *ou_occl_sum,
                                    float *ou_occl_std,
                                    int sx,
                                    int sy,
                                    int sz,
                                    int kernel,
                                    float max_data)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    const int z = blockIdx.z * blockDim.z + threadIdx.z;

    if ((x >= sx) || (y >= sy) || (z >= sz)) {
        return;
    }

    const int k = kernel;

    // 64-bit linear index: sx * sy * sz may exceed INT_MAX for large volumes.
    const size_t id = (size_t)x + (size_t)y * sx + (size_t)z * sx * sy;

    // --- Neighbourhood cube sum ---
    // Loop bounds are clipped to the volume up front instead of skipping
    // out-of-range offsets. Integer coordinates hit exact voxels, so a direct
    // read equals what the interpolator would return. The visiting order
    // (z, then y, then x, all ascending) is unchanged, so the float sum is too.
    const int x_lo = (x - k > 0) ? x - k : 0;
    const int x_hi = (x + k < sx - 1) ? x + k : sx - 1;
    const int y_lo = (y - k > 0) ? y - k : 0;
    const int y_hi = (y + k < sy - 1) ? y + k : sy - 1;
    const int z_lo = (z - k > 0) ? z - k : 0;
    const int z_hi = (z + k < sz - 1) ? z + k : sz - 1;

    float sum = 0.0f;
    for (int zz = z_lo; zz <= z_hi; zz++) {
        for (int yy = y_lo; yy <= y_hi; yy++) {
            const size_t row = (size_t)yy * sx + (size_t)zz * sx * sy;
            for (int xx = x_lo; xx <= x_hi; xx++) {
                sum += in_data[row + xx];
            }
        }
    }
    ou_occl_sum[id] = sum;

    // --- Directional spread ---
    const float avg = sum / 360.0f / 180.0f;

    // Degrees to radians (the previous code used 3.1415f for pi).
    const float deg2rad = 3.14159265358979f / 180.0f;

    // One theta loop over [0, 180) replaces three identical loops over
    // [0, 60), [60, 120) and [120, 180); the accumulation order is unchanged.
    float sq_dev = 0.0f;
    for (int theta = 0; theta < 180; theta++) {
        const float theta_sin = sinf(theta * deg2rad);
        const float theta_cos = cosf(theta * deg2rad);

        for (int psi = 0; psi < 360; psi += 2) {
            const float psi_sin = sinf(psi * deg2rad);
            const float psi_cos = cosf(psi * deg2rad);

            // Sum of samples along this ray; out-of-volume samples count
            // as max_data.
            float ray = 0.0f;
            for (int r = 0; r < k; r++) {
                const float vx = r * theta_sin * psi_cos;
                const float vy = r * theta_sin * psi_sin;
                const float vz = r * theta_cos;

                float sample = 0.0f;
                if (interpolate_coor(x + vx, y + vy, z + vz, in_data, sx, sy, sz, &sample)) {
                    ray += sample;
                } else {
                    ray += max_data;
                }
            }

            const float dev = ray - avg;
            sq_dev += dev * dev;
        }
    }

    ou_occl_std[id] = sqrtf(sq_dev);
}




// Report a failed CUDA runtime call on stderr.
// Returns true when `err` is cudaSuccess, so calls can be chained with &&.
static bool cuda_ok(cudaError_t err, const char *what)
{
    if (err == cudaSuccess) {
        return true;
    }
    cerr << "CUDA error in " << what << ": " << cudaGetErrorString(err) << endl;
    return false;
}

// Write `count` floats to "<prefix>.dat" as raw binary, and the volume
// dimensions and element type to "<prefix>.hdr" as text.
// Returns false (after printing a message) if either file cannot be written.
static bool write_volume(const char *prefix, const float *data, size_t count,
                         int sx, int sy, int sz)
{
    const string dat_name = string(prefix) + ".dat";
    cout << "write " << dat_name << endl;
    // Binary mode: the payload is raw float bytes.
    ofstream dat(dat_name.c_str(), ios::out | ios::binary);
    dat.write(reinterpret_cast<const char *>(data), sizeof(float) * count);
    dat.close();
    if (dat.fail()) {
        cerr << "Cannot write " << dat_name << endl;
        return false;
    }

    const string hdr_name = string(prefix) + ".hdr";
    cout << "write " << hdr_name << endl;
    ofstream hdr(hdr_name.c_str());
    hdr << sx << " " << sy << " " << sz << endl;
    hdr << "float" << endl;
    hdr.close();
    if (hdr.fail()) {
        cerr << "Cannot write " << hdr_name << endl;
        return false;
    }
    return true;
}

// Usage: in_file x y z out_file_sum out_file_std
//
// Reads x*y*z raw floats from in_file and runs occl_kernel_3d_mask over the
// volume on the GPU. Writes <out_file_sum>.dat/.hdr and
// <out_file_std>.dat/.hdr. Returns 0 on success and 1 on any error.
int main(int argc, char **argv)
{
    if (argc != 7) {
        cout << "Usage: " << argv[0] << " in_file x y z out_file_sum out_file_std"  << endl;
        return 1;
    }

    const char *input_file      = argv[1];
    const int   sx              = atoi(argv[2]);
    const int   sy              = atoi(argv[3]);
    const int   sz              = atoi(argv[4]);
    const char *output_file_sum = argv[5];
    const char *output_file_std = argv[6];

    // atoi returns 0 for non-numeric input; an empty volume would make the
    // in_data[0] read below out of bounds.
    if (sx <= 0 || sy <= 0 || sz <= 0) {
        cerr << "Volume dimensions must be positive integers" << endl;
        return 1;
    }

    // Length of the volume diagonal, used as the kernel's neighbourhood
    // half-width and ray length. Computed in double so the squares cannot
    // overflow int.
    const double ex = sx - 1.0, ey = sy - 1.0, ez = sz - 1.0;
    const int kernel = (int)sqrt(ex * ex + ey * ey + ez * ez);

    // Number of voxels and bytes; size_t so large volumes do not overflow int.
    const size_t st    = (size_t)sx * sy * sz;
    const size_t bytes = sizeof(float) * st;

    vector<float> in_data(st, 0.0f);
    vector<float> ou_occl_sum(st, 0.0f);
    vector<float> ou_occl_std(st, 0.0f);

    // Read the raw data file (binary mode: the contents are raw floats).
    ifstream inf(input_file, ios::in | ios::binary);
    if (!inf) {
        cerr << "Cannot open " << input_file << endl;
        return 1;
    }
    cout << "Open " << input_file << endl;
    inf.read(reinterpret_cast<char *>(&in_data[0]), bytes);
    const size_t got = (size_t)inf.gcount();
    inf.close();
    if (got != bytes) {
        // A short file means the given dimensions do not match the data.
        cerr << "Expected " << bytes << " bytes in " << input_file
             << " but read " << got << endl;
        return 1;
    }

    float max_data = in_data[0];
    for (size_t i = 1; i < st; ++i) {
        if (in_data[i] > max_data) {
            max_data = in_data[i];
        }
    }
    cout<< "max = "<< max_data << endl;

    float *cuda_in_data     = NULL;
    float *cuda_ou_occl_sum = NULL;
    float *cuda_ou_occl_std = NULL;

    // 16x16x1 = 256 threads per block. A 1024-thread block risks "too many
    // resources requested for launch" with this register-heavy kernel. Each
    // thread handles one voxel independently, so results do not depend on
    // the block shape.
    dim3 blockSize(16, 16, 1);
    dim3 gridSize = dim3(iDivUp(sx, blockSize.x),
                         iDivUp(sy, blockSize.y),
                         iDivUp(sz, blockSize.z));

    bool ok =
        cuda_ok(cudaMalloc((void **) &cuda_in_data,     bytes), "cudaMalloc in_data") &&
        cuda_ok(cudaMalloc((void **) &cuda_ou_occl_sum, bytes), "cudaMalloc occl_sum") &&
        cuda_ok(cudaMalloc((void **) &cuda_ou_occl_std, bytes), "cudaMalloc occl_std") &&
        cuda_ok(cudaMemcpy(cuda_in_data, &in_data[0], bytes, cudaMemcpyHostToDevice),
                "cudaMemcpy in_data") &&
        cuda_ok(cudaMemset(cuda_ou_occl_sum, 0, bytes), "cudaMemset occl_sum") &&
        cuda_ok(cudaMemset(cuda_ou_occl_std, 0, bytes), "cudaMemset occl_std");

    if (ok) {
        occl_kernel_3d_mask <<< gridSize, blockSize >>> (
            cuda_in_data,
            cuda_ou_occl_sum,
            cuda_ou_occl_std,
            sx,
            sy,
            sz,
            kernel,
            max_data);

        // Launch-configuration errors surface via cudaGetLastError; faults
        // during execution surface at the synchronize.
        ok = cuda_ok(cudaGetLastError(), "kernel launch") &&
             cuda_ok(cudaDeviceSynchronize(), "kernel execution") &&
             cuda_ok(cudaMemcpy(&ou_occl_sum[0], cuda_ou_occl_sum, bytes, cudaMemcpyDeviceToHost),
                     "cudaMemcpy occl_sum") &&
             cuda_ok(cudaMemcpy(&ou_occl_std[0], cuda_ou_occl_std, bytes, cudaMemcpyDeviceToHost),
                     "cudaMemcpy occl_std");
    }

    // cudaFree(NULL) is a no-op, so this is safe after a partial failure.
    cudaFree(cuda_in_data);
    cudaFree(cuda_ou_occl_sum);
    cudaFree(cuda_ou_occl_std);
    cudaDeviceReset();  // replaces the deprecated cudaThreadExit()

    if (!ok) {
        return 1;
    }

    if (!write_volume(output_file_sum, &ou_occl_sum[0], st, sx, sy, sz) ||
        !write_volume(output_file_std, &ou_occl_std[0], st, sx, sy, sz)) {
        return 1;
    }

    return 0;
}


