@@ -4,21 +4,18 @@ import (
44 "math"
55 "math/rand"
66 "sync"
7+ "sync/atomic"
78 "time"
89
910 "gonum.org/v1/gonum/floats"
1011)
1112
1213const (
13- CHANGES_THRESHOLD = 2
14- MEAN_THRESHOLD = 0.05
14+ CHANGES_THRESHOLD = 2
15+ GOROUTINE_THRESHOLD = 2
16+ MEAN_THRESHOLD = 0.05
1517)
1618
17- type Online struct {
18- alpha float64
19- dimension int
20- }
21-
2219type kmeansClusterer struct {
2320 iterations int
2421 number int
@@ -49,7 +46,7 @@ type kmeansClusterer struct {
4946 c []HardCluster
5047}
5148
52- func KmeansClusterer (iterations , clusters int , distance DistanceFunc , online ... Online ) (HardClusterer , error ) {
49+ func KmeansClusterer (iterations , clusters int , distance DistanceFunc ) (HardClusterer , error ) {
5350 if iterations < 1 {
5451 return nil , ErrZeroIterations
5552 }
@@ -67,23 +64,20 @@ func KmeansClusterer(iterations, clusters int, distance DistanceFunc, online ...
6764 }
6865 }
6966
70- var o Online
71- {
72- if len (online ) > 0 {
73- o = online [0 ]
74- }
75- }
76-
7767 return & kmeansClusterer {
7868 iterations : iterations ,
7969 number : clusters ,
8070 distance : d ,
81-
82- alpha : o .alpha ,
83- dimension : o .dimension ,
8471 }, nil
8572}
8673
74+ func (c * kmeansClusterer ) WithOnline (o Online ) HardClusterer {
75+ c .alpha = o .Alpha
76+ c .dimension = o .Dimension
77+
78+ return c
79+ }
80+
8781func (c * kmeansClusterer ) Learn (data [][]float64 ) error {
8882 if len (data ) == 0 {
8983 return ErrEmptySet
@@ -161,17 +155,21 @@ func (c *kmeansClusterer) Predict(p []float64) (HardCluster, error) {
161155 return l , nil
162156}
163157
164- func (c * kmeansClusterer ) Online (observations chan []float64 , done chan struct {}) chan []HardCluster {
158+ func (c * kmeansClusterer ) Online (observations chan []float64 , done chan struct {}) chan int {
159+ if c .alpha == 0 || c .dimension == 0 {
160+ return nil
161+ }
162+
165163 c .d = make ([][]float64 , 0 , 100 )
166164
167165 c .initializeMeans ()
168166
169167 var (
170- w sync. WaitGroup
171- r chan [] HardCluster = make (chan [] HardCluster )
172- b [] float64 = make ([] float64 , len (c .m [0 ]) )
173- k , l , f int = 0 , len ( c . m ), len ( c . m [ 0 ])
174- m , n , am float64 = 0 , 0 , 1 - c . alpha
168+ r chan int = make ( chan int )
169+ b [] float64 = make ([] float64 , len ( c . m [ 0 ]) )
170+ k , l , f int = 0 , len ( c . m ) , len (c .m [0 ])
171+ m , n , am float64 = 0 , 0 , 1 - c . alpha
172+ s uint32
175173 )
176174
177175 go func () {
@@ -182,7 +180,7 @@ func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}
182180 k = 0
183181
184182 for i := 1 ; i < l ; i ++ {
185- if n = squaredDistance (o , c .m [1 ]); n < m {
183+ if n = squaredDistance (o , c .m [i ]); n < m {
186184 m = n
187185 k = i
188186 }
@@ -196,48 +194,44 @@ func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}
196194 c.m [k ][i ] = c .alpha * o [i ] + am * c.m [k ][i ]
197195 }
198196
199- // Only trigger update if change of a centroid was siginificant, else send unchanged set
200- if ! floats .EqualApprox (b , c .m [k ], MEAN_THRESHOLD ) {
201- go func (data [][]float64 , p []float64 ) {
202- w .Wait ()
203-
204- w .Add (1 )
197+ r <- k
205198
199+ /* Only trigger update if change of a centroid was
200+ * siginificant and goroutine limit is not reached
201+ */
202+ if atomic .LoadUint32 (& s ) < GOROUTINE_THRESHOLD && ! floats .EqualApprox (b , c .m [k ], MEAN_THRESHOLD ) {
203+ go func (p []float64 ) {
206204 c .mu .Lock ()
205+ atomic .AddUint32 (& s , 1 )
207206
208207 var (
209208 n int
210209 d , m float64
211210 )
212211
213- data = append (data , p )
212+ c . d = append (c . d , p )
214213
215214 for i := 0 ; i < c .number ; i ++ {
216215 c .c [i ] = c .c [i ][:0 ]
217216 }
218217
219- for i := 0 ; i < len (data ); i ++ {
220- m = c .distance (data [i ], c .m [0 ])
218+ for i := 0 ; i < len (c . d ); i ++ {
219+ m = c .distance (c . d [i ], c .m [0 ])
221220 n = 0
222221
223222 for j := 1 ; j < c .number ; j ++ {
224- if d = c .distance (data [i ], c .m [j ]); d < m {
223+ if d = c .distance (c . d [i ], c .m [j ]); d < m {
225224 m = d
226225 n = j
227226 }
228227 }
229228
230- c .c [n ] = append (c .c [n ], data [i ])
229+ c .c [n ] = append (c .c [n ], c . d [i ])
231230 }
232231
232+ atomic .AddUint32 (& s , ^ uint32 (0 ))
233233 c .mu .Unlock ()
234-
235- w .Done ()
236-
237- r <- c .c
238- }(c .d , o )
239- } else {
240- r <- c .c
234+ }(o )
241235 }
242236 case <- done :
243237 return
@@ -266,7 +260,7 @@ func (c *kmeansClusterer) initializeMeans() {
266260
267261 for i := 0 ; i < c .number ; i ++ {
268262 c .m [i ] = make ([]float64 , c .dimension )
269- for j := 0 ; j < c .dimension ; i ++ {
263+ for j := 0 ; j < c .dimension ; j ++ {
270264 c.m [i ][j ] = 10 * (rand .Float64 () - 0.5 )
271265 }
272266 }
0 commit comments