001/*
002 * Copyright 2010-2015 Institut Pasteur.
003 * 
004 * This file is part of Icy.
005 * 
006 * Icy is free software: you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation, either version 3 of the License, or
009 * (at your option) any later version.
010 * 
011 * Icy is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
014 * GNU General Public License for more details.
015 * 
016 * You should have received a copy of the GNU General Public License
017 * along with Icy. If not, see <http://www.gnu.org/licenses/>.
018 */
019package icy.math;
020
021import java.util.Arrays;
022
023/**
024 * Implementation of the Hungarian / Munkres-Kuhn algorithm<br>
025 * for rectangular assignment problem.
026 * 
027 * @author Nicolas Chenouard & Stephane
028 */
029public class HungarianAlgorithm
030{
031    final int numRow;
032    final int numCol;
033    final int k;
034    final double[][] costs;
035
036    final int[] rowsStar;
037    final int[] colsStar;
038    final int[] rowsPrime;
039    final boolean[] rowsCovered;
040    final boolean[] colsCovered;
041
042    final int[] colsUnStar;
043    final int[] rowsDoStar;
044
045    int step;
046    boolean done;
047
048    /**
049     * Create the optimizer.
050     * 
051     * @param values
052     *        Table of assignment costs.<br>
053     */
054    public HungarianAlgorithm(double[][] values)
055    {
056        int r, c;
057        final int initialNumCol = values[0].length;
058
059        numRow = values.length;
060        // number of column should >= number of row
061        numCol = Math.max(initialNumCol, numRow);
062        k = Math.min(numRow, numCol);
063        costs = new double[numRow][numCol];
064
065        // find maximum value
066        double max = values[0][0];
067        if (initialNumCol < numRow)
068        {
069            for (r = 0; r < values.length; r++)
070            {
071                final double v = ArrayMath.max(values[r]);
072                if (v > max)
073                    max = v;
074            }
075        }
076
077        // enlarge matrix with max value if necessary
078        for (r = 0; r < values.length; r++)
079        {
080            final double[] rowValues = values[r];
081            final double[] rowCosts = costs[r];
082
083            for (c = 0; c < rowValues.length; c++)
084                rowCosts[c] = rowValues[c];
085            for (; c < numCol; c++)
086                rowCosts[c] = max;
087        }
088
089        rowsStar = new int[numRow];
090        colsStar = new int[numCol];
091        rowsPrime = new int[numRow];
092        rowsCovered = new boolean[numRow];
093        colsCovered = new boolean[numCol];
094
095        colsUnStar = new int[numCol];
096        rowsDoStar = new int[numRow];
097
098        Arrays.fill(rowsPrime, -1);
099    }
100
101    /**
102     * Resolve and returns result in this form : <code>result[row] = column</code>
103     */
104    public int[] resolve()
105    {
106        initialReduce();
107
108        done = false;
109        step = 2;
110        while (!done)
111        {
112            switch (step)
113            {
114                case 2:
115                    updateStar();
116                    break;
117
118                case 3:
119                    doColCover();
120                    break;
121
122                case 4:
123                    doPrime();
124                    break;
125
126                case 5:
127                    // done inner 4
128                    break;
129
130                case 6:
131                    reduce();
132                    break;
133            }
134        }
135
136        return rowsStar;
137    }
138
139    // For each row we find the row minimum and subtract it from all entries on that row.
140    private void initialReduce()
141    {
142        for (int r = 0; r < numRow; r++)
143        {
144            final double[] rowCosts = costs[r];
145
146            // get row minimum cost
147            final double min = ArrayMath.min(rowCosts);
148
149            // subtract it to all entries
150            for (int c = 0; c < numCol; c++)
151                rowCosts[c] -= min;
152        }
153    }
154
155    // update starring
156    private void updateStar()
157    {
158        Arrays.fill(rowsStar, -1);
159        Arrays.fill(colsStar, -1);
160
161        for (int r = 0; r < numRow; r++)
162            updateRowStar(r);
163
164        step = 3;
165    }
166
167    private void updateRowStar(int r)
168    {
169        final double[] rowCosts = costs[r];
170
171        for (int c = 0; c < numCol; c++)
172        {
173            if (colsStar[c] == -1)
174            {
175                if (rowCosts[c] == 0)
176                {
177                    rowsStar[r] = c;
178                    colsStar[c] = r;
179                    return;
180                }
181            }
182        }
183    }
184
185    // cover column with contained star
186    private void doColCover()
187    {
188        Arrays.fill(colsCovered, false);
189
190        int numColCovered = 0;
191        for (int c = 0; c < numCol; c++)
192        {
193            if (colsStar[c] != -1)
194            {
195                colsCovered[c] = true;
196                numColCovered++;
197            }
198        }
199
200        if (numColCovered == k)
201            done = true;
202        else
203            step = 4;
204    }
205
206    // prime uncovered zero
207    private void doPrime()
208    {
209        for (int c = 0; c < numCol; c++)
210            if (!colsCovered[c])
211                if (doPrimCol(c))
212                    return;
213
214        step = 6;
215    }
216
217    // prime specified column
218    private boolean doPrimCol(int c)
219    {
220        for (int r = 0; r < numRow; r++)
221        {
222            if (!rowsCovered[r])
223            {
224                // no covered zero ?
225                if (costs[r][c] == 0)
226                {
227                    // prime it
228                    rowsPrime[r] = c;
229
230                    // get star column for this row ?
231                    final int starCol = rowsStar[r];
232
233                    // no star on this row
234                    if (starCol == -1)
235                    {
236                        convertPrimeToStar(r, c);
237                        return true;
238                    }
239
240                    rowsCovered[r] = true;
241                    colsCovered[starCol] = false;
242
243                    // so we don't forget newly uncovered zeros
244                    if (starCol < c)
245                        if (doPrimCol(starCol))
246                            return true;
247                }
248            }
249        }
250
251        return false;
252    }
253
254    // convert all prime found on the way to star
255    private void convertPrimeToStar(int r, int c)
256    {
257        int nb = 0;
258
259        int primeCol = c;
260        int starRow = colsStar[primeCol];
261
262        while (starRow != -1)
263        {
264            colsUnStar[nb] = primeCol;
265            rowsDoStar[nb] = starRow;
266            nb++;
267
268            primeCol = rowsPrime[starRow];
269            starRow = colsStar[primeCol];
270        }
271
272        for (int i = 0; i < nb; i++)
273        {
274            final int startCol = colsUnStar[i];
275
276            // unstar
277            rowsStar[colsStar[startCol]] = -1;
278            colsStar[startCol] = -1;
279        }
280
281        for (int i = 0; i < nb; i++)
282        {
283            final int primeRow = rowsDoStar[i];
284            final int pc = rowsPrime[primeRow];
285
286            // star
287            colsStar[pc] = primeRow;
288            rowsStar[primeRow] = pc;
289        }
290        // star
291        colsStar[c] = r;
292        rowsStar[r] = c;
293
294        Arrays.fill(rowsPrime, -1);
295        Arrays.fill(rowsCovered, false);
296        Arrays.fill(colsCovered, false);
297
298        step = 3;
299    }
300
301    // reduce costs
302    private void reduce()
303    {
304        double min = Double.MAX_VALUE;
305
306        // find minimum of uncovered elements
307        for (int r = 0; r < numRow; r++)
308        {
309            if (!rowsCovered[r])
310            {
311                final double[] rowCosts = costs[r];
312
313                for (int c = 0; c < numCol; c++)
314                {
315                    if (!colsCovered[c])
316                    {
317                        final double v = rowCosts[c];
318
319                        if (v < min)
320                            min = v;
321                    }
322                }
323            }
324        }
325
326        // subtract minimum from uncovered elements
327        // and add it to double covered elements
328        for (int r = 0; r < numRow; r++)
329        {
330            final double[] rowCosts = costs[r];
331
332            if (rowsCovered[r])
333            {
334                for (int c = 0; c < numCol; c++)
335                    if (colsCovered[c])
336                        rowCosts[c] += min;
337            }
338            else
339            {
340                for (int c = 0; c < numCol; c++)
341                    if (!colsCovered[c])
342                        rowCosts[c] -= min;
343            }
344        }
345
346        step = 4;
347    }
348}