-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathfastcluster.py
More file actions
196 lines (160 loc) · 5.5 KB
/
fastcluster.py
File metadata and controls
196 lines (160 loc) · 5.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#!/usr/bin/env python
import os, sys
import ctypes
import tables
import numpy as np
import numpy.random as npr
# Load the native libraries through ctypes.  libfastann is resolved by the
# dynamic loader only; libfastcluster is tried the same way first and then
# from the current working directory.
libfastann = ctypes.CDLL('libfastann.so')
try:
    libfastcluster = ctypes.CDLL('libfastcluster.so')
except OSError:
    # Not found on the loader search path: fall back to a local copy.
    libfastcluster = ctypes.CDLL('./libfastcluster.so')
def get_suffix(dtype):
    """
    Return the one-letter suffix that selects the native routine for `dtype`.

    Params
    ------
    dtype : numpy.dtype
        One of np.dtype('u1'), np.dtype('f4') or np.dtype('f8').

    Returns
    -------
    'c' for 'u1', 's' for 'f4', 'd' for 'f8'.

    Raises
    ------
    KeyError
        If `dtype` is not one of the three supported dtypes.
    """
    code_to_letter = (('u1', 'c'),
                      ('f4', 's'),
                      ('f8', 'd'))
    table = dict((np.dtype(code), letter) for code, letter in code_to_letter)
    return table[dtype]
class nn_obj_exact_builder(object):
    """
    Builds exact nearest-neighbour objects by calling into libfastann.

    Params
    ------
    dtype : numpy.dtype
        Point datatype; selects which fastann_nn_obj_build_exact_<suffix>
        routine is called (see get_suffix).
    """
    def __init__(self, dtype):
        self.dtype = dtype
        self.suffix = get_suffix(dtype)

    def build_nn_obj(self, p, clusters, K, D):
        """
        Call fastann_nn_obj_build_exact_<suffix>(clusters, K, D).

        Params
        ------
        p : unused
        clusters : int or None
            Address of the cluster data, passed on as a void pointer
        K : int
            Passed on as an unsigned int
        D : int
            Passed on as an unsigned int

        Returns
        -------
        The pointer returned by the native routine (an int, or None for NULL).
        """
        build = getattr(libfastann, "fastann_nn_obj_build_exact_" + self.suffix)
        # ctypes assumes a C int return type unless told otherwise, which
        # truncates a returned pointer on 64-bit platforms.
        build.restype = ctypes.c_void_p
        # Bare Python ints are passed as C int, so the address must be wrapped
        # to survive intact.
        # NOTE(review): the argument types mirror the ones
        # nn_obj_approx_builder uses for the same leading arguments -- confirm
        # against the libfastann header.
        ptr = build(ctypes.c_void_p(clusters), ctypes.c_uint(K), ctypes.c_uint(D))
        return ptr
class nn_obj_approx_builder(object):
    """
    Builds k-d tree nearest-neighbour objects by calling into libfastann.

    Params
    ------
    dtype : numpy.dtype
        Point datatype; selects which fastann_nn_obj_build_kdtree_<suffix>
        routine is called (see get_suffix).
    ntrees : int
        Forwarded to the native routine as an unsigned int
    nchecks : int
        Forwarded to the native routine as an unsigned int
    """
    def __init__(self, dtype, ntrees, nchecks):
        self.dtype = dtype
        self.suffix = get_suffix(dtype)
        self.ntrees = ntrees
        self.nchecks = nchecks

    def build_nn_obj(self, p, clusters, K, D):
        """
        Call fastann_nn_obj_build_kdtree_<suffix>(clusters, K, D, ntrees, nchecks).

        Params
        ------
        p : unused
        clusters : int or None
            Address of the cluster data, passed on as a void pointer
        K : int
            Passed on as an unsigned int
        D : int
            Passed on as an unsigned int

        Returns
        -------
        The pointer returned by the native routine (an int, or None for NULL).
        """
        build = getattr(libfastann, "fastann_nn_obj_build_kdtree_" + self.suffix)
        # ctypes assumes a C int return type unless told otherwise, which
        # truncates a returned pointer on 64-bit platforms.
        build.restype = ctypes.c_void_p
        ptr = build(ctypes.c_void_p(clusters),
                    ctypes.c_uint(K),
                    ctypes.c_uint(D),
                    ctypes.c_uint(self.ntrees),
                    ctypes.c_uint(self.nchecks))
        return ptr
class hdf5_wrap(object):
    """
    Copies row ranges of a sliceable 2-D point container into raw memory.

    Params
    ------
    pnts_obj : object supporting pnts_obj[l:r] -> numpy array
        Source of the points
    dt : numpy.dtype
        Datatype the rows are converted to before being copied out
    """
    def __init__(self, pnts_obj, dt):
        self.dt = dt
        self.pnts_obj = pnts_obj

    def read_rows(self, p, l, r, out):
        """
        Copy rows [l, r) of the points, converted to `dt`, to address `out`.

        Params
        ------
        p : unused
        l : int
            First row to read
        r : int
            One past the last row to read
        out : int
            Destination address; the caller must provide room for the rows
        """
        # Force a C-ordered buffer of the target dtype so that the raw byte
        # copy below lays the rows out one after another.
        pnts = np.ascontiguousarray(self.pnts_obj[l:r], dtype=self.dt)
        pnts_ptr = pnts.ctypes.data_as(ctypes.c_void_p)
        # Copy exactly what was read.  Slicing clips at the end of the data,
        # so sizing the copy from (r - l) could read past the end of `pnts`.
        ctypes.memmove(ctypes.c_void_p(out), pnts_ptr, pnts.nbytes)
def kmeans(clst_fn,
           pnts_fn,
           K,
           niters = 30,
           approx = True,
           ntrees = 8,
           nchecks = 784,
           checkpoint = True,
           seed = 42):
    """
    Runs the distributed approximate k-means algorithm.

    Params
    ------
    clst_fn : string
        HDF5 filename for the cluster output
    pnts_fn : string
        HDF5 filename for the points to cluster
    K : int
        Number of clusters
    niters : int (30)
        Number of iterations
    approx : bool (True)
        Exact or approximate nn
    ntrees : int (8)
        Size of the k-d forest
    nchecks : int (784)
        Number of point distances to compute per query
    checkpoint : bool (True)
        Whether to checkpoint
    seed : int (42)
        Random seed

    Raises
    ------
    RuntimeError
        If libfastcluster.safe_init() returns a non-zero code
    TypeError
        If the points are not stored as 'u1', 'f4' or 'f8'
    ValueError
        If `pnts_fn` holds no Array node, or K is not between 1 and the
        number of points
    """
    errc = libfastcluster.safe_init()
    if errc:
        raise RuntimeError('problem with mpi_init')

    npr.seed(seed)

    # Probe for datatype and dimensionality
    pnts_fobj = tables.open_file(pnts_fn, 'r')
    try:
        # The first Array node found in the file is used as the point set.
        pnts_obj = next(iter(pnts_fobj.walk_nodes('/', classname = 'Array')), None)
        if pnts_obj is None:
            raise ValueError('No Array node found in %s' % pnts_fn)

        N = pnts_obj.shape[0]
        D = pnts_obj.shape[1]
        dtype = pnts_obj.atom.dtype

        if dtype not in [np.dtype('u1'), np.dtype('f4'), np.dtype('f8')]:
            raise TypeError('Datatype %s not supported' % dtype)

        # Fewer points than clusters would leave rows of the np.empty
        # cluster buffer uninitialized.
        if not 0 < K <= N:
            raise ValueError('K must be between 1 and the number of points '
                             '(%s), got %s' % (N, K))

        # 'u1' points are clustered as 'f4': the row loader converts on read.
        if dtype == np.dtype('u1'):
            dtype = np.dtype('f4')

        if approx:
            nn_builder = nn_obj_approx_builder(dtype, ntrees, nchecks)
        else:
            nn_builder = nn_obj_exact_builder(dtype)

        pnt_loader = hdf5_wrap(pnts_obj, dtype)

        # Callbacks
        LOAD_ROWS_FUNC = \
            ctypes.CFUNCTYPE(None, ctypes.c_void_p, ctypes.c_uint, ctypes.c_uint,
                             ctypes.c_void_p)
        NN_BUILDER_FUNC = \
            ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p,
                             ctypes.c_uint, ctypes.c_uint)

        # Keep these wrappers in local variables so they stay referenced for
        # the duration of the native call below.
        load_rows_func = LOAD_ROWS_FUNC(pnt_loader.read_rows)
        nn_builder_func = NN_BUILDER_FUNC(nn_builder.build_nn_obj)

        # Space for the clusters
        clusters = np.empty((K, D), dtype = dtype)
        clusters_ptr = clusters.ctypes.data_as(ctypes.c_void_p)

        # Initialize the clusters
        if libfastcluster.rank() == 0:
            sys.stdout.write('Sampling cluster centers...')
            sys.stdout.flush()
            pnts_inds = np.arange(N)
            npr.shuffle(pnts_inds)
            pnts_inds = pnts_inds[:K]
            pnts_inds = np.sort(pnts_inds)
            # Report progress roughly every 1% of K; the floor of 1 avoids a
            # modulo by zero when K < 100.
            step = max(1, K // 100)
            for i, ind in enumerate(pnts_inds):
                clusters[i] = pnts_obj[ind]
                if not (i % step):
                    sys.stdout.write('\r[%07d/%07d]' % (i, K))
                    sys.stdout.flush()
            sys.stdout.write('Done...')
            sys.stdout.flush()

        if checkpoint:
            chkpnt_fn = clst_fn + '.chkpnt'
        else:
            chkpnt_fn = ''

        # c_char_p needs a byte string; text filenames are encoded first.
        chkpnt_arg = chkpnt_fn
        if not isinstance(chkpnt_arg, bytes):
            chkpnt_arg = chkpnt_arg.encode(sys.getfilesystemencoding() or 'utf-8')

        getattr(libfastcluster, "kmeans_" + get_suffix(dtype))\
            (load_rows_func,
             ctypes.c_void_p(0),
             nn_builder_func,
             ctypes.c_void_p(0),
             clusters_ptr,
             ctypes.c_uint(N),
             ctypes.c_uint(D),
             ctypes.c_uint(K),
             ctypes.c_uint(niters),
             ctypes.c_int(0),
             ctypes.c_char_p(chkpnt_arg))

        # All done, save the clusters
        if libfastcluster.rank() == 0:
            filters = tables.Filters(complevel = 1, complib = 'zlib')
            clst_fobj = tables.open_file(clst_fn, 'w')
            try:
                clst_obj = \
                    clst_fobj.create_carray(clst_fobj.root, 'clusters',
                                            tables.Atom.from_dtype(dtype),
                                            clusters.shape,
                                            filters = filters)
                clst_obj[:] = clusters
            finally:
                clst_fobj.close()

            if chkpnt_fn != '':
                try:
                    os.remove(chkpnt_fn)
                except OSError:
                    # Best effort: the checkpoint file may not exist.
                    pass
    finally:
        # The row-loading callback reads from this file during the native
        # call, so it is only closed once everything above has finished.
        pnts_fobj.close()
if __name__ == "__main__":
    # Script entry point: cluster the points in ./pnts_float_rootsift.h5 and
    # write the centers to ./clst_1M_iter10.h5.
    ntrees = 8
    niters = 10
    K = 1000003

    # Initialise the native library and call its barrier before starting.
    libfastcluster.safe_init()
    libfastcluster.barrier()

    kmeans(clst_fn='./clst_1M_iter10.h5',
           pnts_fn='./pnts_float_rootsift.h5',
           K=K,
           niters=niters,
           ntrees=ntrees)