Skip to content

Commit 5c2880d

Browse files
committed
Added upper limit of cluster update goroutines for online learning
1 parent a6e261e commit 5c2880d

2 files changed

Lines changed: 49 additions & 46 deletions

File tree

clusters.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ import (
66

77
type DistanceFunc func(a, b []float64) float64
88

9+
type Online struct {
10+
Alpha float64
11+
Dimension int
12+
}
13+
914
type HardCluster [][]float64
1015

1116
type SoftCluster struct {
@@ -24,7 +29,9 @@ type HardClusterer interface {
2429

2530
Predict(observation []float64) (HardCluster, error)
2631

27-
Online(observations chan []float64, done chan struct{}) chan []HardCluster
32+
Online(observations chan []float64, done chan struct{}) chan int
33+
34+
WithOnline(params Online) HardClusterer
2835

2936
Clusterer
3037
}
@@ -34,7 +41,9 @@ type SoftClusterer interface {
3441

3542
Predict(observation []float64) (*SoftCluster, error)
3643

37-
Online(observations chan []float64, done chan struct{}) chan []*SoftCluster
44+
Online(observations chan []float64, done chan struct{}) chan int
45+
46+
WithOnline(params Online) SoftClusterer
3847

3948
Clusterer
4049
}

kmeans.go

Lines changed: 38 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,18 @@ import (
44
"math"
55
"math/rand"
66
"sync"
7+
"sync/atomic"
78
"time"
89

910
"gonum.org/v1/gonum/floats"
1011
)
1112

1213
const (
13-
CHANGES_THRESHOLD = 2
14-
MEAN_THRESHOLD = 0.05
14+
CHANGES_THRESHOLD = 2
15+
GOROUTINE_THRESHOLD = 2
16+
MEAN_THRESHOLD = 0.05
1517
)
1618

17-
type Online struct {
18-
alpha float64
19-
dimension int
20-
}
21-
2219
type kmeansClusterer struct {
2320
iterations int
2421
number int
@@ -49,7 +46,7 @@ type kmeansClusterer struct {
4946
c []HardCluster
5047
}
5148

52-
func KmeansClusterer(iterations, clusters int, distance DistanceFunc, online ...Online) (HardClusterer, error) {
49+
func KmeansClusterer(iterations, clusters int, distance DistanceFunc) (HardClusterer, error) {
5350
if iterations < 1 {
5451
return nil, ErrZeroIterations
5552
}
@@ -67,23 +64,20 @@ func KmeansClusterer(iterations, clusters int, distance DistanceFunc, online ...
6764
}
6865
}
6966

70-
var o Online
71-
{
72-
if len(online) > 0 {
73-
o = online[0]
74-
}
75-
}
76-
7767
return &kmeansClusterer{
7868
iterations: iterations,
7969
number: clusters,
8070
distance: d,
81-
82-
alpha: o.alpha,
83-
dimension: o.dimension,
8471
}, nil
8572
}
8673

74+
func (c *kmeansClusterer) WithOnline(o Online) HardClusterer {
75+
c.alpha = o.Alpha
76+
c.dimension = o.Dimension
77+
78+
return c
79+
}
80+
8781
func (c *kmeansClusterer) Learn(data [][]float64) error {
8882
if len(data) == 0 {
8983
return ErrEmptySet
@@ -161,17 +155,21 @@ func (c *kmeansClusterer) Predict(p []float64) (HardCluster, error) {
161155
return l, nil
162156
}
163157

164-
func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}) chan []HardCluster {
158+
func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}) chan int {
159+
if c.alpha == 0 || c.dimension == 0 {
160+
return nil
161+
}
162+
165163
c.d = make([][]float64, 0, 100)
166164

167165
c.initializeMeans()
168166

169167
var (
170-
w sync.WaitGroup
171-
r chan []HardCluster = make(chan []HardCluster)
172-
b []float64 = make([]float64, len(c.m[0]))
173-
k, l, f int = 0, len(c.m), len(c.m[0])
174-
m, n, am float64 = 0, 0, 1 - c.alpha
168+
r chan int = make(chan int)
169+
b []float64 = make([]float64, len(c.m[0]))
170+
k, l, f int = 0, len(c.m), len(c.m[0])
171+
m, n, am float64 = 0, 0, 1 - c.alpha
172+
s uint32
175173
)
176174

177175
go func() {
@@ -182,7 +180,7 @@ func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}
182180
k = 0
183181

184182
for i := 1; i < l; i++ {
185-
if n = squaredDistance(o, c.m[1]); n < m {
183+
if n = squaredDistance(o, c.m[i]); n < m {
186184
m = n
187185
k = i
188186
}
@@ -196,48 +194,44 @@ func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}
196194
c.m[k][i] = c.alpha*o[i] + am*c.m[k][i]
197195
}
198196

199-
// Only trigger update if change of a centroid was significant, else send unchanged set
200-
if !floats.EqualApprox(b, c.m[k], MEAN_THRESHOLD) {
201-
go func(data [][]float64, p []float64) {
202-
w.Wait()
203-
204-
w.Add(1)
197+
r <- k
205198

199+
/* Only trigger update if change of a centroid was
200+
* significant and goroutine limit is not reached
201+
*/
202+
if atomic.LoadUint32(&s) < GOROUTINE_THRESHOLD && !floats.EqualApprox(b, c.m[k], MEAN_THRESHOLD) {
203+
go func(p []float64) {
206204
c.mu.Lock()
205+
atomic.AddUint32(&s, 1)
207206

208207
var (
209208
n int
210209
d, m float64
211210
)
212211

213-
data = append(data, p)
212+
c.d = append(c.d, p)
214213

215214
for i := 0; i < c.number; i++ {
216215
c.c[i] = c.c[i][:0]
217216
}
218217

219-
for i := 0; i < len(data); i++ {
220-
m = c.distance(data[i], c.m[0])
218+
for i := 0; i < len(c.d); i++ {
219+
m = c.distance(c.d[i], c.m[0])
221220
n = 0
222221

223222
for j := 1; j < c.number; j++ {
224-
if d = c.distance(data[i], c.m[j]); d < m {
223+
if d = c.distance(c.d[i], c.m[j]); d < m {
225224
m = d
226225
n = j
227226
}
228227
}
229228

230-
c.c[n] = append(c.c[n], data[i])
229+
c.c[n] = append(c.c[n], c.d[i])
231230
}
232231

232+
atomic.AddUint32(&s, ^uint32(0))
233233
c.mu.Unlock()
234-
235-
w.Done()
236-
237-
r <- c.c
238-
}(c.d, o)
239-
} else {
240-
r <- c.c
234+
}(o)
241235
}
242236
case <-done:
243237
return
@@ -266,7 +260,7 @@ func (c *kmeansClusterer) initializeMeans() {
266260

267261
for i := 0; i < c.number; i++ {
268262
c.m[i] = make([]float64, c.dimension)
269-
for j := 0; j < c.dimension; i++ {
263+
for j := 0; j < c.dimension; j++ {
270264
c.m[i][j] = 10 * (rand.Float64() - 0.5)
271265
}
272266
}

0 commit comments

Comments
 (0)