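/*************************************************
 * Copyright:
 * Author: Lina Yu, Hongfeng Yu 
 * Description: Read a raw float volume (x * y * z samples)
 * and compute its Laplacian on the GPU: a 3x3 filter when
 * z == 1, a 7-point 3D stencil otherwise. The absolute
 * response is rescaled by its min/max before being written.
 * Output: <out_file>.dat //  normalized |Laplacian|, raw floats
 *        <out_file>.hdr  //  "x y z" dimensions and the type "float"
 **************************************************/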

#include <cassert>
#include <climits>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include <sys/times.h>

#include <cuda.h>

using namespace std;

// Integer division of a by b that is bumped up by one whenever the division
// leaves a non-zero remainder. main() uses it to size the launch grid so the
// blocks cover every sample. b must be non-zero.
int iDivUp(int a, int b)
{
    int quotient = a / b;
    if (a % b != 0) {
        ++quotient;
    }
    return quotient;
}


// Plain pair of single-precision values.
// NOTE(review): not referenced anywhere in the visible part of this file.
struct data_t {
    float x;  // first component
    float y;  // second component
};

// Minimal 3-component float vector offering its Euclidean length and an
// in-place normalization. The methods carry no __device__ qualifier, so
// they are callable from host code only.
class vec3d {
public:
    float x;
    float y;
    float z;

    // Euclidean norm of (x, y, z).
    float length() {
        const float squared = x*x + y*y + z*z;
        return sqrt(squared);
    }

    // Scales the vector to unit length in place. A vector whose length is
    // exactly zero is left untouched so that no division by zero occurs.
    void normalize() {
        const float norm = length();
        if (norm == 0) {
            return;
        }
        x /= norm;
        y /= norm;
        z /= norm;
    }
};



// Samples the scalar volume `data` at the (possibly fractional) position
// (x, y, z) using trilinear interpolation.
//
//   data       : sx * sy * sz floats; sample (i, j, k) is stored at
//                i + j * sx + k * sx * sy
//   sx, sy, sz : number of samples along x, y and z
//
// A position outside the volume is clamped onto the nearest border sample
// (it is NOT treated as zero), so border values are effectively replicated.
// Returns the interpolated value.
inline __device__ float interpolate_coor(float x, float y, float z, float *data, int sx, int sy, int sz)
{
    // Clamp the position into [0, size - 1] on each axis.
    x = (x < 0) ? 0 : x;
    x = (x > (sx - 1)) ? (sx - 1) : x;
    y = (y < 0) ? 0 : y;
    y = (y > (sy - 1)) ? (sy - 1) : y;
    z = (z < 0) ? 0 : z;
    z = (z > (sz - 1)) ? (sz - 1) : z;

    // Lower corner of the enclosing cell (truncation == floor, since the
    // clamped coordinates are non-negative).
    const int ix = (int)x;
    const int iy = (int)y;
    const int iz = (int)z;

    // Linear offsets of the lower row and lower slice.
    const int row_lo   = iy * sx;
    const int slice_lo = iz * sx * sy;

    // Exactly on a grid point: return the stored sample directly.
    if (ix == x && iy == y && iz == z) {
        return data[ix + row_lo + slice_lo];
    }

    // Upper corner of the cell, kept inside the volume.
    const int jx = (ix + 1 > sx - 1) ? (sx - 1) : (ix + 1);
    const int jy = (iy + 1 > sy - 1) ? (sy - 1) : (iy + 1);
    const int jz = (iz + 1 > sz - 1) ? (sz - 1) : (iz + 1);

    const int row_hi   = jy * sx;
    const int slice_hi = jz * sx * sy;

    // Fractional position inside the cell along each axis.
    const float fx = x - ix;
    const float fy = y - iy;
    const float fz = z - iz;

    // Interpolate along x on the four cell edges parallel to the x axis.
    const float edge_y0_z0 = data[ix + row_lo + slice_lo] * (1 - fx) +
                             data[jx + row_lo + slice_lo] * fx;
    const float edge_y1_z0 = data[ix + row_hi + slice_lo] * (1 - fx) +
                             data[jx + row_hi + slice_lo] * fx;
    const float edge_y0_z1 = data[ix + row_lo + slice_hi] * (1 - fx) +
                             data[jx + row_lo + slice_hi] * fx;
    const float edge_y1_z1 = data[ix + row_hi + slice_hi] * (1 - fx) +
                             data[jx + row_hi + slice_hi] * fx;

    // Then along y on the two faces, and finally along z.
    const float face_z0 = edge_y0_z0 * (1 - fy) + edge_y1_z0 * fy;
    const float face_z1 = edge_y0_z1 * (1 - fy) + edge_y1_z1 * fy;

    return face_z0 * (1 - fz) + face_z1 * fz;
}


// Applies a 3x3 filter around every sample of the volume, within its z slice.
//
//   data       : input volume, sx * sy * sz floats (index x + y*sx + z*sx*sy)
//   lapl       : output volume with the same layout; receives the response
//   filter     : 9 coefficients; entries 0..2 weight the y+1 neighbours,
//                3..5 the y row, 6..8 the y-1 neighbours, each ordered
//                x-1, x, x+1
//   sx, sy, sz : volume dimensions
//
// Each thread handles the sample at its global (x, y, z) thread index and
// exits early when that index lies outside the volume. Neighbours beyond
// the border are clamped by interpolate_coor.
__global__ void lapl_kernel_2d(float *data, 
                               float *lapl,
                               float *filter,
                               int sx, 
                               int sy, 
                               int sz)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    const int z = blockIdx.z * blockDim.z + threadIdx.z;

    // Threads of partially filled blocks have nothing to do.
    if (x >= sx || y >= sy || z >= sz) {
        return;
    }

    // Walk the 3x3 window in filter order: top row (y+1) first, and within
    // each row from x-1 to x+1.
    float response = 0;
    for (int row = 0; row < 3; row++) {
        const int ny = y + 1 - row;
        for (int col = 0; col < 3; col++) {
            const int nx = x + col - 1;
            response += interpolate_coor(nx, ny, z, data, sx, sy, sz)
                        * filter[row * 3 + col];
        }
    }

    lapl[x + y * sx + z * sx * sy] = response;
}

// Evaluates the 7-point Laplacian stencil at every sample of the volume:
// the sum of the six axis neighbours minus six times the centre value.
//
//   data       : input volume, sx * sy * sz floats (index x + y*sx + z*sx*sy)
//   lapl       : output volume with the same layout; receives the response
//   sx, sy, sz : volume dimensions
//
// Each thread handles the sample at its global (x, y, z) thread index and
// exits early when that index lies outside the volume. Neighbours beyond
// the border are clamped by interpolate_coor.
__global__ void lapl_kernel_3d(float *data, 
                               float *lapl,
                               int sx, 
                               int sy, 
                               int sz)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    const int z = blockIdx.z * blockDim.z + threadIdx.z;

    // Threads of partially filled blocks have nothing to do.
    if (x >= sx || y >= sy || z >= sz) {
        return;
    }

    // Fetch the centre sample and its six axis-aligned neighbours.
    const float x_prev = interpolate_coor(x - 1, y,     z,     data, sx, sy, sz);
    const float x_next = interpolate_coor(x + 1, y,     z,     data, sx, sy, sz);
    const float y_prev = interpolate_coor(x,     y - 1, z,     data, sx, sy, sz);
    const float y_next = interpolate_coor(x,     y + 1, z,     data, sx, sy, sz);
    const float z_prev = interpolate_coor(x,     y,     z - 1, data, sx, sy, sz);
    const float z_next = interpolate_coor(x,     y,     z + 1, data, sx, sy, sz);
    const float centre = interpolate_coor(x,     y,     z,     data, sx, sy, sz);

    // Accumulate in a fixed order: neighbours first, centre last.
    float response = 0;
    response += x_prev;
    response += x_next;
    response += y_prev;
    response += y_next;
    response += z_prev;
    response += z_next;
    response += -6 * centre;

    lapl[x + y * sx + z * sx * sy] = response;
}


// Terminates the program with a diagnostic when a CUDA runtime call fails.
// `what` names the operation for the error message.
static void lapl_check_cuda(cudaError_t status, const char *what)
{
    if (status != cudaSuccess) {
        std::cerr << "CUDA error during " << what << ": "
                  << cudaGetErrorString(status) << std::endl;
        std::exit(EXIT_FAILURE);
    }
}

// Entry point.
//
// Usage: <program> in_file x y z out_file
//
//   in_file  : raw file holding x * y * z floats
//   x, y, z  : volume dimensions (positive integers)
//   out_file : base name of the outputs
//
// Runs lapl_kernel_2d when z == 1 and lapl_kernel_3d otherwise, takes the
// absolute value of the response, rescales it with its own min/max and
// writes:
//   <out_file>.dat : the rescaled values as raw floats
//   <out_file>.hdr : "x y z" on the first line, "float" on the second
//
// Returns 0 on success and EXIT_FAILURE on any error.
int main(int argc, char **argv)
{   
    if (argc != 6) {
        std::cout << "Usage: " << argv[0] << " in_file x y z out_file"  << std::endl;
        return EXIT_FAILURE;
    }
    
    const char *input_file  = argv[1];
    const int   sx          = atoi(argv[2]);
    const int   sy          = atoi(argv[3]);
    const int   sz          = atoi(argv[4]);
    const char *output_file = argv[5];
    
    // Reject dimensions that would give an empty volume or make the
    // int-based sample count (also used for indexing in the kernels) overflow.
    if (sx <= 0 || sy <= 0 || sz <= 0) {
        std::cerr << "Invalid dimensions " << argv[2] << " " << argv[3]
                  << " " << argv[4] << std::endl;
        return EXIT_FAILURE;
    }
    if (sx > INT_MAX / sy || sx * sy > INT_MAX / sz) {
        std::cerr << "Volume " << sx << " x " << sy << " x " << sz
                  << " is too large" << std::endl;
        return EXIT_FAILURE;
    }
    
    // 3x3 Laplacian coefficients used for the single-slice case.
    const float lapl_2d_kernel[] 
        = {-1, -1, -1,
           -1,  8, -1,
           -1, -1, -1};
        
    // number of data items and their size in bytes
    const int    num   = sx * sy * sz;
    const size_t bytes = sizeof(float) * static_cast<size_t>(num);
    
    // Host buffers; vectors are zero-initialized and released on every
    // return path.
    std::vector<float> data(num);
    std::vector<float> lapl(num);
    
    // read the raw data file
    std::ifstream inf(input_file, std::ios::in | std::ios::binary);
    if (!inf) {
        std::cerr << "Cannot open " << input_file << std::endl;
        return EXIT_FAILURE;
    }
    std::cout << "Open " << input_file << std::endl;
    inf.read(reinterpret_cast<char *>(&data[0]),
             static_cast<std::streamsize>(bytes));
    if (static_cast<size_t>(inf.gcount()) != bytes) {
        // Keep going as before (missing samples stay zero) but say so.
        std::cerr << "Warning: " << input_file << " holds only "
                  << inf.gcount() << " of the expected " << bytes
                  << " bytes; remaining samples are zero" << std::endl;
    }
    inf.close();
    
    float *cuda_data = NULL;
    float *cuda_lapl = NULL;
    float *cuda_filter = NULL;
    
    lapl_check_cuda(cudaMalloc((void **) &cuda_data, bytes),
                    "allocation of the input volume");
    lapl_check_cuda(cudaMalloc((void **) &cuda_lapl, bytes),
                    "allocation of the output volume");
    lapl_check_cuda(cudaMalloc((void **) &cuda_filter, sizeof(lapl_2d_kernel)),
                    "allocation of the filter");
    
    lapl_check_cuda(cudaMemcpy(cuda_data, &data[0], bytes,
                               cudaMemcpyHostToDevice),
                    "upload of the input volume");
    lapl_check_cuda(cudaMemcpy(cuda_filter, lapl_2d_kernel,
                               sizeof(lapl_2d_kernel),
                               cudaMemcpyHostToDevice),
                    "upload of the filter");
    
    // One thread per sample: 32 x 32 threads per block in x/y, one block
    // layer per z slice.
    dim3 blockSize(32, 32, 1);
    dim3 gridSize = dim3(iDivUp(sx, blockSize.x), 
                         iDivUp(sy, blockSize.y), 
                         iDivUp(sz, blockSize.z));
    
    if (sz == 1) {
        lapl_kernel_2d<<<gridSize, blockSize>>> (
            cuda_data,
            cuda_lapl,
            cuda_filter,
            sx,
            sy,
            sz);
    } else {
        lapl_kernel_3d<<<gridSize, blockSize>>> (
            cuda_data,
            cuda_lapl,
            sx,
            sy,
            sz);
    }
    
    // A launch does not return an error itself: configuration problems show
    // up in cudaGetLastError(), execution faults at the next synchronization.
    lapl_check_cuda(cudaGetLastError(), "kernel launch");
    lapl_check_cuda(cudaDeviceSynchronize(), "kernel execution");
    
    lapl_check_cuda(cudaMemcpy(&lapl[0], cuda_lapl, bytes,
                               cudaMemcpyDeviceToHost),
                    "download of the result");
    
    lapl_check_cuda(cudaFree(cuda_data), "release of the input volume");
    lapl_check_cuda(cudaFree(cuda_lapl), "release of the output volume");
    lapl_check_cuda(cudaFree(cuda_filter), "release of the filter");
    
    lapl_check_cuda(cudaDeviceReset(), "device reset");
    
    // normalization: take magnitudes, then track their extrema. The extrema
    // are seeded from the magnitude of the first value; seeding them from
    // the signed value would leave `min` negative whenever lapl[0] < 0.
    float min, max;
    min = max = std::fabs(lapl[0]);
    for (int i = 0; i < num; i++) {
        lapl[i] = std::fabs(lapl[i]);
        if (min > lapl[i]) min = lapl[i];
        if (max < lapl[i]) max = lapl[i];
    }
    
    std::cout << "min " << min << " max " << max << std::endl;
    
    // Map [min, max] onto [0, 1]. A constant response has no range to
    // stretch, so it is written as all zeros instead of dividing by zero.
    const float range = max - min;
    for (int i = 0; i < num; i++) {
        lapl[i] = (range > 0) ? (lapl[i] - min) / range : 0.0f;
    }
    
    // Build the file names with std::string so that an arbitrarily long
    // output path cannot overrun a fixed-size buffer.
    const std::string base(output_file);
    
    const std::string dat_name = base + ".dat";
    std::cout << "write " << dat_name << std::endl;
    std::ofstream dat_out(dat_name.c_str(), std::ios::out | std::ios::binary);
    if (!dat_out) {
        std::cerr << "Cannot open " << dat_name << std::endl;
        return EXIT_FAILURE;
    }
    dat_out.write(reinterpret_cast<const char *>(&lapl[0]), 
                  static_cast<std::streamsize>(bytes));
    dat_out.close();
    if (!dat_out) {
        std::cerr << "Cannot write " << dat_name << std::endl;
        return EXIT_FAILURE;
    }
    
    const std::string hdr_name = base + ".hdr";
    std::cout << "write " << hdr_name << std::endl;
    std::ofstream hdr_out(hdr_name.c_str());
    if (!hdr_out) {
        std::cerr << "Cannot open " << hdr_name << std::endl;
        return EXIT_FAILURE;
    }
    hdr_out << sx << " " << sy << " " << sz << std::endl;
    hdr_out << "float" << std::endl;
    hdr_out.close();
    if (!hdr_out) {
        std::cerr << "Cannot write " << hdr_name << std::endl;
        return EXIT_FAILURE;
    }
    
    return 0;
}


