Skip to content

Commit 30090d9

Browse files
committed
Moved computation of clusters in online learning after all data has been sent
1 parent 5c2880d commit 30090d9

2 files changed

Lines changed: 52 additions & 68 deletions

File tree

clusters.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,23 @@ type Clusterer interface {
2727
type HardClusterer interface {
2828
Guesses() []HardCluster
2929

30-
Predict(observation []float64) (HardCluster, error)
30+
Predict(observation []float64) HardCluster
3131

32-
Online(observations chan []float64, done chan struct{}) chan int
32+
Online(observations chan []float64, done chan struct{}, callback func([]float64, int))
3333

34-
WithOnline(params Online) HardClusterer
34+
WithOnline(Online) HardClusterer
3535

3636
Clusterer
3737
}
3838

3939
type SoftClusterer interface {
4040
Guesses() []*SoftCluster
4141

42-
Predict(observation []float64) (*SoftCluster, error)
42+
Predict(observation []float64) *SoftCluster
4343

44-
Online(observations chan []float64, done chan struct{}) chan int
44+
Online(observations chan []float64, done chan struct{}, callback func())
4545

46-
WithOnline(params Online) SoftClusterer
46+
WithOnline(Online) SoftClusterer
4747

4848
Clusterer
4949
}

kmeans.go

Lines changed: 46 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"math"
55
"math/rand"
66
"sync"
7-
"sync/atomic"
87
"time"
98

109
"gonum.org/v1/gonum/floats"
@@ -41,7 +40,7 @@ type kmeansClusterer struct {
4140
// Training set
4241
d [][]float64
4342

44-
// Computed clusters. Access is synchronized to ascertain no incorrect predictions are made.
43+
// Computed clusters. Access is synchronized.
4544
mu sync.RWMutex
4645
c []HardCluster
4746
}
@@ -75,6 +74,10 @@ func (c *kmeansClusterer) WithOnline(o Online) HardClusterer {
7574
c.alpha = o.Alpha
7675
c.dimension = o.Dimension
7776

77+
c.d = make([][]float64, 0, 100)
78+
79+
c.initializeMeans()
80+
7881
return c
7982
}
8083

@@ -135,10 +138,7 @@ func (c *kmeansClusterer) Guesses() []HardCluster {
135138
return c.c
136139
}
137140

138-
func (c *kmeansClusterer) Predict(p []float64) (HardCluster, error) {
139-
c.mu.RLock()
140-
defer c.mu.RUnlock()
141-
141+
func (c *kmeansClusterer) Predict(p []float64) HardCluster {
142142
var (
143143
l HardCluster
144144
d float64
@@ -152,32 +152,30 @@ func (c *kmeansClusterer) Predict(p []float64) (HardCluster, error) {
152152
}
153153
}
154154

155-
return l, nil
155+
return l
156156
}
157157

158-
func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}) chan int {
159-
if c.alpha == 0 || c.dimension == 0 {
160-
return nil
161-
}
162-
163-
c.d = make([][]float64, 0, 100)
164-
165-
c.initializeMeans()
158+
func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}, callback func([]float64, int)) {
159+
c.mu.Lock()
166160

167161
var (
168-
r chan int = make(chan int)
169-
b []float64 = make([]float64, len(c.m[0]))
170-
k, l, f int = 0, len(c.m), len(c.m[0])
171-
m, n, am float64 = 0, 0, 1 - c.alpha
172-
s uint32
162+
l, f int = len(c.m), len(c.m[0])
163+
h float64 = 1 - c.alpha
173164
)
174165

166+
/* The first step of online learning is adjusting the centroids by finding the one closest to the new data point
167+
* and modifying its location using the given alpha. Once the client stops sending new data, the actual clusters
168+
* are computed and the mutex is unlocked. */
169+
175170
go func() {
176171
for {
177172
select {
178173
case o := <-observations:
179-
m = squaredDistance(o, c.m[0])
180-
k = 0
174+
var (
175+
k int
176+
n float64
177+
m float64 = squaredDistance(o, c.m[0])
178+
)
181179

182180
for i := 1; i < l; i++ {
183181
if n = squaredDistance(o, c.m[i]); n < m {
@@ -186,60 +184,46 @@ func (c *kmeansClusterer) Online(observations chan []float64, done chan struct{}
186184
}
187185
}
188186

189-
for i := 0; i < f; i++ {
190-
b[i] = c.m[k][i]
191-
}
187+
go callback(o, k)
192188

193189
for i := 0; i < f; i++ {
194-
c.m[k][i] = c.alpha*o[i] + am*c.m[k][i]
190+
c.m[k][i] = c.alpha*o[i] + h*c.m[k][i]
195191
}
196192

197-
r <- k
198-
199-
/* Only trigger update if change of a centroid was
200-
* significant and goroutine limit is not reached
201-
*/
202-
if atomic.LoadUint32(&s) < GOROUTINE_THRESHOLD && !floats.EqualApprox(b, c.m[k], MEAN_THRESHOLD) {
203-
go func(p []float64) {
204-
c.mu.Lock()
205-
atomic.AddUint32(&s, 1)
206-
207-
var (
208-
n int
209-
d, m float64
210-
)
193+
c.d = append(c.d, o)
194+
case <-done:
195+
go func() {
196+
var (
197+
n int
198+
l int = len(c.d) / c.number
199+
d, m float64
200+
)
201+
202+
for i := 0; i < c.number; i++ {
203+
c.c[n] = make([][]float64, 0, l)
204+
}
211205

212-
c.d = append(c.d, p)
206+
for i := 0; i < len(c.d); i++ {
207+
m = c.distance(c.d[i], c.m[0])
208+
n = 0
213209

214-
for i := 0; i < c.number; i++ {
215-
c.c[i] = c.c[i][:0]
210+
for j := 1; j < c.number; j++ {
211+
if d = c.distance(c.d[i], c.m[j]); d < m {
212+
m = d
213+
n = j
214+
}
216215
}
217216

218-
for i := 0; i < len(c.d); i++ {
219-
m = c.distance(c.d[i], c.m[0])
220-
n = 0
221-
222-
for j := 1; j < c.number; j++ {
223-
if d = c.distance(c.d[i], c.m[j]); d < m {
224-
m = d
225-
n = j
226-
}
227-
}
217+
c.c[n] = append(c.c[n], c.d[i])
218+
}
228219

229-
c.c[n] = append(c.c[n], c.d[i])
230-
}
220+
c.mu.Unlock()
221+
}()
231222

232-
atomic.AddUint32(&s, ^uint32(0))
233-
c.mu.Unlock()
234-
}(o)
235-
}
236-
case <-done:
237223
return
238224
}
239225
}
240226
}()
241-
242-
return r
243227
}
244228

245229
// private

0 commit comments

Comments
 (0)