Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 176 additions & 0 deletions matlab/BMatchingSolver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*!
* BMatchingSolverMex
* Bert Huang
*/

#include <iostream>
#include <math.h>
#include "mex.h"
#include "BMatchingLibrary.h"
#include "SparseMatrix.h"
#include "utils.h"

using namespace std;
using namespace bmatchingLibrary;

double ** getMatrix(const mxArray *pm) {
double ** A;
double * matrix = mxGetPr(pm);
int m = mxGetM(pm);
int n = mxGetN(pm);

if (m == 0)
return 0;

A = new double*[m];

for (int i = 0; i < m; i++)
A[i] = new double[n];

for (int i = 0; i < m; i++)
for (int j = 0; j < n; j++)
A[i][j] = matrix[j*m+i];

return A;
}

void deleteMatrix(double ** A, int size) {
for (int i = 0; i < size; i++)
delete[](A[i]);
delete[](A);
}

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
int weightType = 0, d = 0, m = 0, n = 0, cacheSize = 0,
wsize = 0, xsize = 0, ysize = 0;
double ** W = 0, ** X = 0, ** Y = 0;
int * b, blen;
bool matrix = true;
bool verbose = false;


//(W, b, X, Y, weightType, cacheSize)
W = getMatrix(prhs[0]);

m = mxGetM(prhs[0]);
wsize = m;
n = mxGetN(prhs[0]);
double * bdouble = mxGetPr(prhs[1]);
blen = mxGetM(prhs[1]) + mxGetN(prhs[1]) - 1;
b = new int[blen];
for (int i = 0; i < blen; i++)
b[i] = round(bdouble[i]);

if (nrhs > 2) {
X = getMatrix(prhs[2]);

xsize = mxGetM(prhs[2]);

if (xsize > 0) {
m = mxGetM(prhs[2]);
d = mxGetN(prhs[2]);
matrix = false;
weightType = 1; // default weight type
if (nrhs > 3) {
Y = getMatrix(prhs[3]);
n = mxGetM(prhs[3]);
ysize = n;
} else
n = 0;
} else
X = 0;
}

if (nrhs > 4 && mxGetM(prhs[4]) > 0)
weightType = round(mxGetScalar(prhs[4]));
if (nrhs > 5 && mxGetM(prhs[4]) > 0)
cacheSize = round(mxGetScalar(prhs[5]));
else
cacheSize = round(2 * sqrt(m + n));

if (nrhs > 6)
verbose = mxGetScalar(prhs[6]) > 0;

SparseMatrix<bool> * solution;

// By default, perform no more than 100*(m+n) iterations
int maxIter = 100*(m+n);

if (blen == m) {
// unipartite
if (matrix)
solution = bMatchMatrixCache(m, W, b, cacheSize, maxIter, verbose);
else if (weightType == 1)
solution = bMatchEuclideanCache(m, d, X, b, cacheSize, maxIter, verbose);
else if (weightType == 2)
solution = bMatchInnerProductCache(m, d, X, b, cacheSize, maxIter, verbose);
else
mexErrMsgTxt("Unrecognized weight type");
} else {
// bipartite
if (matrix)
solution = bMatchBipartiteMatrixCache(m, n, W, b, b+m,
cacheSize, maxIter, verbose);
else if (weightType == 1)
solution = bMatchBipartiteEuclideanCache(m, n, d, X, Y, b,
b+m, cacheSize, maxIter, verbose);
else if (weightType == 2)
solution = bMatchBipartiteInnerProductCache(m, n, d, X, Y, b,
b+m, cacheSize, maxIter, verbose);
else
mexErrMsgTxt("Unrecognized weight type");
}

if (wsize > 0)
deleteMatrix(W, wsize);
if (xsize > 0)
deleteMatrix(X, xsize);
if (ysize > 0)
deleteMatrix(Y, ysize);

delete[](b);


int nnz = solution->getNNz();

mxArray * I = mxCreateDoubleMatrix(nnz, 1, mxREAL);
mxArray * J = mxCreateDoubleMatrix(nnz, 1, mxREAL);
mxArray * V = mxCreateDoubleMatrix(nnz, 1, mxREAL);

double * rows = new double[nnz];
double * cols = new double[nnz];
double * vals = new double[nnz];

for (int i=0; i < nnz; i++) {
rows[i] = solution->getRows()[i]+1;
cols[i] = solution->getCols()[i]+1;
vals[i] = 1.0;
}

delete(solution);

memcpy(mxGetPr(I), rows, nnz*sizeof(double));
memcpy(mxGetPr(J), cols, nnz*sizeof(double));
memcpy(mxGetPr(V), vals, nnz*sizeof(double));

delete[](rows);
delete[](cols);
delete[](vals);

mxArray * rhs[5];

rhs[0] = I;
rhs[1] = J;
rhs[2] = V;
rhs[3] = mxCreateDoubleScalar(blen);
rhs[4] = mxCreateDoubleScalar(blen);

mexCallMATLAB(1, plhs, 5, rhs, "sparse");

for (int i = 0; i < 5; i++)
mxDestroyArray(rhs[i]);



return;
}
90 changes: 90 additions & 0 deletions matlab/BMatchingSolverCmd.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
function [A, time] = BMatchingSolver(W, b, X, Y, weightType, cacheSize, flags)

% solves for the maximum weight b-matching
% W is the weight matrix (option 1)
% b is the vector of target degrees
% X is the first bipartition (option 2)
% Y is the second bipartition (option 2)
% weightType is 1 for negative Euclidean distance, 2 for inner product
% cacheSize is the size of the weight cache, default is 2*sqrt(m+n)
% calls mex version if mex version is compiled and in current path

[m, n] = size(W);
if exist('weightType', 'var') && ~isempty(weightType) && ismember(weightType, [1 2])
m = size(X,1);
n = size(Y,1);
end

global tmp_dir
if isempty(tmp_dir)
tmp_dir = '~/tmp';
end

global bmatchingsolver
if isempty(bmatchingsolver)
bmatchingsolver = '~/Dropbox/workspace/BMatchingSolver/Release/BMatchingSolver';
end

persistent problem_id;
if isempty(problem_id)
problem_id = uint32(randi(9999));
else
problem_id = uint32(mod(problem_id + randi(9999), 100000));
end

outFile = sprintf('%s/tmp_%d_output.txt', tmp_dir, problem_id);
degFile = sprintf('%s/tmp_%d_degrees.txt', tmp_dir, problem_id);
dlmwrite(degFile, b, 'precision', '%9.0f');

if (m == n && length(b) == m)
% assume unipartite
cmd = sprintf('%s -n %d -d %s -o %s', bmatchingsolver, m, ...
degFile, outFile);
N = m;
else
% bipartite
cmd = sprintf('%s -n %d --bipartite %d -d %s -o %s', bmatchingsolver, m+n, m, ...
degFile, outFile);
N = m+n;
end

if ~isempty(W)
weightFile = sprintf('%s/tmp_%d_weights.txt', tmp_dir, problem_id);
save(weightFile, 'W', '-ascii', '-double');

cmd = sprintf('%s -w %s', cmd, weightFile);
elseif exist('weightType', 'var')
dataFile = sprintf('%s/tmp_%d_data.txt', tmp_dir, problem_id);
data = [X; Y];
save(dataFile, 'data', '-ascii', '-double');

cmd = sprintf('%s -x %s -t %d -D %d', cmd, dataFile, weightType,...
size(data,2));
end

% add cache size parameter
if ~exist('cacheSize', 'var')
cacheSize = round(2*sqrt(m+n));
end
cmd = sprintf('%s -c %d', cmd, cacheSize);

if exist('flags', 'var')
cmd = sprintf('%s %s', cmd, flags);
end

tic;
system(cmd);
time = toc;

IJ = dlmread(outFile, ' ') + 1;
A = sparse(IJ(:,1), IJ(:,2), true(size(IJ,1),1), N, N);

delete(degFile);
if (exist('weightFile', 'var'))
delete(weightFile);
end
if (exist('dataFile', 'var'))
delete(dataFile);
end
delete(outFile);

60 changes: 60 additions & 0 deletions matlab/bdmatch_augment.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
function [W, X, Y, b, I, J] = bdmatch_augment(W, X, Y, lb, ub)
% function [W, X, Y, b] = bdmatch_mex(W, X, Y, lb, ub)
% creates augmented b-matching to solve bd-matching problem

b = ub;

auxcount = ub - lb;

maxaux = max(auxcount);

unused = maxaux - auxcount;

if ~isempty(W)
[m,n] = size(W);

W = [W zeros(m, maxaux); zeros(maxaux, n+maxaux)];
if (m == n && length(lb) == m)
% unipartite
for i = 1:m
W(i,end-unused(i)+1:end) = -inf;
W(end-unused(i)+1:end,i) = -inf;
end
b = [b(:); -ones(maxaux,1)];
I = 1:m;
J = 1:m;
else
% bipartite
for i = 1:m
W(i,end-unused(i)+1:end) = -inf;
end
for i = 1:n
W(end-unused(i+m)+1:end,i) = -inf;
end
b = [b(1:m); -ones(maxaux,1); b(m+1:end); -ones(maxaux,1)];

I = 1:m;
J = m+maxaux+1:m+maxaux+n;
end
elseif ~isempty(X)
[m,d] = size(X);
[n,d] = size(Y);

if any(unused(1:m) ~= unused(1)) || n > 0 && any(unused(m+1:end) ~= unused(m+1))
error('This script can only handle Euclidean or inner product problems with the same lb and ub in each bipartition');
end

if (n > 0)
% bipartite
X = [X; nan(auxcount(m+1), d)];
Y = [Y; nan(auxcount(1), d)];
b = [b(1:m); -ones(auxcount(m+1),1); b(m+1:end); -ones(auxcount(1),1)];
I = 1:m;
J = m+auxcount(m+1)+1:m+auxcount(m+1)+n;
else
X = [X; nan(auxcount(1), d)];
b = [b(:); -ones(auxcount(1),1)];
I = 1:m;
J = 1:m;
end
end
46 changes: 46 additions & 0 deletions matlab/lprelax.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
function [f, A, b] = lprelax(W, mask, lb, ub)

N = size(W,1);

if (N ~= length(lb))
[m,n] = size(W);
W = [sparse(m, m) W; W' sparse(n, n)];
mask = [sparse(m, m) mask; mask' sparse(n, n)];
N = n+m;
bipartite = true;
else
bipartite = false;
end

f = -nonzeros(triu(W.*mask));

[I,J] = find(triu(mask));


A = zeros(2*N, nnz(triu(mask)))';
b = zeros(2*N, 1);

for i=1:N
A(I==i | J==i, i) = 1;
b(i) = ub(i);

A(I==i | J==i, i+N) = -1;
b(i+N) = -lb(i);
end
A = A';

if (nargout==1)

options.Display = 'none';

x = linprog(f, A, b, [], [], zeros(size(f)), ones(size(f)), ...
zeros(size(f)), options);

X = sparse(I, J, x, N, N);

f = X+triu(X,1)';

if bipartite
f = f(1:m, m+1:end);
end
end
Loading