#include "maxweight.h"

#include <cstring>
#include <vector>

// Default constructor. There is nothing to set up here: every routine in this
// file works on buffers that the caller passes in explicitly.
MaxWeight::MaxWeight() {}


// Returns the larger of a and b (b when a < b, otherwise a).
int MaxWeight::max(int a, int b) {
    return (a < b) ? b : a;
}

// Returns the smaller of a and b (a when a < b, otherwise b).
int MaxWeight::min(int a, int b) {
    return (a < b) ? a : b;
}

/**
    Pads a row_num x col_num cost matrix out to a square max x max matrix.

    Every cell (i, j) with i, j < max that lies outside the real region
    (i >= row_num or j >= col_num) is set to one padding value. That value is
    the smallest non-zero entry of the real region, or 100000000 when every
    real entry is zero.

    When max == max(row_num, col_num), each real row (or column) must be
    matched to a real cell, so the padding value shifts every complete
    assignment by the same constant. It therefore does not influence which
    real pairs hungarian() selects.

    matrix - must already be allocated with at least max rows of max ints.
    max    - side of the padded square; normally max(row_num, col_num).
**/
void MaxWeight::fillMatrix(int** matrix, int row_num, int col_num, int max) {
    // Smallest non-zero entry of the real region (sentinel if none).
    int min = 100000000;
    for (int i = 0; i < row_num; i++)
        for (int j = 0; j < col_num; j++)
            if (matrix[i][j] < min && matrix[i][j] != 0)
                min = matrix[i][j];

    // One rule covers both "extra columns" and "extra rows" padding, and also
    // the case max > max(row_num, col_num), which the previous two-branch
    // version left partly unwritten.
    for (int i = 0; i < max; i++)
        for (int j = 0; j < max; j++)
            if (i >= row_num || j >= col_num)
                matrix[i][j] = min;
}

/**
    Builds the initial labeling for the Hungarian method:
    lx[i] = max(0, largest entry of row i of cost) and ly[j] = 0.

    lx, ly - at least num ints each, overwritten.
    cost   - num x num matrix.
**/
void MaxWeight::init_labels(int* lx, int* ly, int num, int** cost) {
    memset(ly, 0, num * sizeof(int));
    for (int row = 0; row < num; ++row) {
        const int* entries = cost[row];
        int best = 0;  // labels never start below zero
        for (int col = 0; col < num; ++col) {
            if (best < entries[col])
                best = entries[col];
        }
        lx[row] = best;
    }
}

/**
    Adjusts the labels when the breadth-first search runs out of tight edges.

    delta is the smallest slack over Y-vertices not in T (starting from
    100000000). Then:
      - lx drops by delta for X-vertices in S;
      - ly rises by delta for Y-vertices in T;
      - slack of every Y-vertex outside T shrinks by delta.
**/
void MaxWeight::update_labels(int* lx, int* ly, int num, bool* S, bool* T,
                              int* slack) {
    int delta = 100000000;
    for (int k = 0; k < num; ++k)
        if (!T[k] && slack[k] < delta)
            delta = slack[k];

    // The three updates touch independent entries, so one pass suffices.
    for (int k = 0; k < num; ++k) {
        if (S[k])
            lx[k] -= delta;
        if (T[k])
            ly[k] += delta;
        else
            slack[k] -= delta;
    }
}

/**
    Attaches X-vertex current_x to the alternating tree.

    current_x - vertex being added; it is flagged in S.
    prevx     - X-vertex preceding current_x on the alternating path, recorded
                in prev[current_x]. Conceptually this adds the edges
                (prevx, xy[current_x]) and (xy[current_x], current_x).

    For every Y-vertex y, if the reduced cost lx[current_x] + ly[y] -
    cost[current_x][y] beats slack[y], slack[y] takes that value and slackx[y]
    records current_x as the vertex attaining it.
**/
void MaxWeight::add_to_tree(int* lx, int* ly, int num, int* prev, int current_x,
                            int prevx, int** cost, int* slack, int* slackx,
                            bool* S) {
    S[current_x] = true;
    prev[current_x] = prevx;

    const int label_x = lx[current_x];
    const int* row = cost[current_x];
    for (int y = 0; y < num; ++y) {
        const int reduced = label_x + ly[y] - row[y];
        if (reduced < slack[y]) {
            slack[y] = reduced;
            slackx[y] = current_x;
        }
    }
}

/**
    Grows the matching one augmenting path at a time until max_match reaches
    num. This is the loop form of the labeling-based Hungarian method. Each
    round:
      1. roots an alternating tree at a free X-vertex;
      2. searches tight edges (cost == lx + ly) breadth-first;
      3. calls update_labels() whenever the search stalls;
      4. flips the edges along the path once a free Y-vertex is reached.

    max_match     - matched pairs already present in xy/yx. Taken by value, so
                    the caller's variable is not updated.
    num           - size of both vertex sets; cost is num x num.
    q             - BFS queue, at least num ints.
    S, T          - scratch membership flags for the tree's X / Y vertices.
    prev          - tree parent of each X-vertex (-2 marks the root).
    xy, yx        - the matching, updated in place (-1 = unmatched).
    lx, ly        - vertex labels, updated in place.
    slack, slackx - per-Y minimum reduced cost and the X-vertex attaining it.
**/
void MaxWeight::augment(int max_match, int num, int* q, bool* S, bool* T,
                        int* prev, int* xy, int* yx, int* lx, int* ly,
                        int* slack, int* slackx, int** cost) {
    int x, y, root;
    // '<' rather than '!=': a max_match above num must not loop forever.
    while (max_match < num) {
        int wr = 0, rd = 0;  // write / read positions in the BFS queue

        memset(S, false, num*sizeof(bool));
        memset(T, false, num*sizeof(bool));
        memset(prev, -1, num*sizeof(int));
        root = -1;
        for (x = 0; x < num; x++) {
            if (xy[x] == -1) {
                q[wr++] = root = x;
                prev[x] = -2;
                S[x] = true;
                break;
            }
        }
        // No free X-vertex: the matching is already perfect. Stop instead of
        // indexing lx/cost with an undefined root.
        if (root == -1)
            break;

        // Seed every Y-vertex's slack from the root.
        for (y = 0; y < num; y++) {
            slack[y] = lx[root] + ly[y] - cost[root][y];
            slackx[y] = root;
        }
        while (true) {
            // Breadth-first search over tight edges.
            while (rd < wr) {
                x = q[rd++];
                for (y = 0; y < num; y++) {
                    if ((cost[x][y] == lx[x] + ly[y]) && (!T[y])) {
                        // Free Y-vertex reached: the path ends with (x, y).
                        if (yx[y] == -1) break;
                        T[y] = true;
                        q[wr++] = yx[y];
                        add_to_tree(lx, ly, num, prev, yx[y], x, cost, slack,
                                    slackx, S);
                    }
                }
                if (y < num) break;
            }
            if (y < num) break;

            // Search stalled: adjust labels so new tight edges appear, then
            // refill the queue with the X-vertices those edges bring in.
            update_labels(lx, ly, num, S, T, slack);
            wr = rd = 0;

            for (y = 0; y < num; y++) {
                if (!T[y] && slack[y] == 0) {
                    if (yx[y] == -1) {
                        // Free Y-vertex now reachable from slackx[y].
                        x = slackx[y];
                        break;
                    } else {
                        T[y] = true;
                        if (!S[yx[y]]) {
                            q[wr++] = yx[y];

                            add_to_tree(lx, ly, num, prev, yx[y], slackx[y],
                                        cost, slack, slackx, S);
                        }
                    }
                }
            }
            if (y < num) break;
        }
        // Flip the edges along the path, walking prev back to the root
        // (prev == -2); this grows the matching by one pair.
        if (y < num) {
            max_match++;
            for (int cx = x, cy = y, ty; cx != -2; cx = prev[cx], cy = ty) {
                ty = xy[cx];
                yx[cy] = cx;
                xy[cx] = cy;
            }
        }
    }
}

/**
    Solves maximum-weight assignment on the num x num matrix cost and returns
    the weight of the real part of the assignment.

    The matching is left in xy (row -> column) and yx (column -> row). The
    remaining pointer arguments are scratch buffers of at least num entries.

    When row_num < col_num the result sums cost[x][xy[x]] over the first
    row_num rows. Otherwise it sums cost[yx[y]][y] over the first col_num
    columns.

    NOTE(review): this assumes num == max(row_num, col_num) and that the
    matrix was padded (e.g. by fillMatrix) — confirm with callers.
**/
int MaxWeight::hungarian(int* lx, int* ly, int num, int* q, bool* S, bool* T,
                         int* prev, int* xy, int* yx, int** cost, int* slack,
                         int* slackx, int row_num, int col_num) {
    memset(xy, -1, num*sizeof(int));
    memset(yx, -1, num*sizeof(int));
    memset(q, 0, num*sizeof(int));
    init_labels(lx, ly, num, cost);
    // Start from an empty matching.
    augment(0, num, q, S, T, prev, xy, yx, lx, ly, slack, slackx, cost);

    int total = 0;
    if (row_num < col_num) {
        for (int row = 0; row < row_num; ++row)
            total += cost[row][xy[row]];
    } else {
        for (int col = 0; col < col_num; ++col)
            total += cost[yx[col]][col];
    }
    return total;
}

/**
    Groups the nodes g->adjList[1..nodeNum] by vertex_label.

    Returns a map from each label to the set of ids nodeMap[vertex_id] carrying
    it. nodeMap is described by the original author as "key is new node id,
    value is original node id".

    nodeMap is a by-value copy, so the operator[] lookup below never modifies
    the caller's map. Note, however, that a vertex_id absent from nodeMap is
    recorded as id 0.
**/
map< int, set<int> > MaxWeight::pickLabelMap(GRAPH* g, int nodeNum,
                                             map<int, int> nodeMap) {
    map< int, set<int> > groups;
    for (int idx = 1; idx <= nodeNum; ++idx) {
        const int label = g->adjList[idx].vertex_label;
        const int mapped = nodeMap[g->adjList[idx].vertex_id];
        groups[label].insert(mapped);  // creates the label's set on first use
    }
    return groups;
}

/**
    Fills the contingency table between two partitions.

    a[r][c] - number of members shared by the r-th cluster of partition_1 and
              the c-th cluster of partition_2. Clusters are indexed in
              ascending key order.
    clusters_1[r], clusters_2[c] - sizes of those clusters.

    a must provide partition_1.size() rows of partition_2.size() ints.
**/
void MaxWeight::findIntersectionNum(map< int, set<int> > partition_1,
                                    map< int, set<int> > partition_2,
                                    int** a, int* clusters_1, int* clusters_2) {
    typedef map< int, set<int> >::iterator PartIter;
    typedef set<int>::const_iterator MemberIter;

    int r = 0;
    for (PartIter p1 = partition_1.begin(); p1 != partition_1.end(); ++p1, ++r) {
        const set<int>& left = p1->second;
        int c = 0;
        for (PartIter p2 = partition_2.begin(); p2 != partition_2.end();
             ++p2, ++c) {
            const set<int>& right = p2->second;
            // Both sets are sorted, so one merge pass counts common members.
            int shared = 0;
            MemberIter u = left.begin();
            MemberIter v = right.begin();
            while (u != left.end() && v != right.end()) {
                if (*u < *v) {
                    ++u;
                } else if (*v < *u) {
                    ++v;
                } else {
                    ++shared;
                    ++u;
                    ++v;
                }
            }
            a[r][c] = shared;
        }
    }

    int k = 0;
    for (PartIter p = partition_1.begin(); p != partition_1.end(); ++p)
        clusters_1[k++] = p->second.size();
    k = 0;
    for (PartIter p = partition_2.begin(); p != partition_2.end(); ++p)
        clusters_2[k++] = p->second.size();
}

/**
    Greedy (not necessarily optimal) row-to-column pairing on the
    rowNum x colNum matrix a.

    Rows are scanned in order. Each row takes its largest strictly positive
    entry among columns not already taken by an earlier row; ties go to the
    lowest column index. A row with no such entry takes nothing.

    marked - at least colNum ints; marked[j] is set to 1 for every taken
             column. Other entries are left untouched, so the caller is
             expected to zero the array beforehand.
    Returns the sum of the taken entries (0 for an empty matrix).

    Previously only the global maximum was taken: the per-row statements sat
    outside the outer loop, and columns were never flagged as visited.
**/
int MaxWeight::findOptimal(int** a, int rowNum, int colNum, int* marked) {
    if (rowNum <= 0 || colNum <= 0)
        return 0;

    // Columns already claimed by an earlier row; released automatically.
    std::vector<bool> visited(colNum, false);
    int total_result = 0;
    for (int i = 0; i < rowNum; i++) {
        int best = 0;     // only strictly positive entries are taken
        int best_j = -1;  // -1: no eligible column found in this row
        for (int j = 0; j < colNum; j++) {
            if (a[i][j] > best && !visited[j]) {
                best = a[i][j];
                best_j = j;
            }
        }
        if (best_j < 0)
            continue;  // do not mark a column for a row that took nothing
        visited[best_j] = true;
        marked[best_j] = 1;
        total_result += best;
    }
    return total_result;
}
