Skip to content

Commit dee95e5

Browse files
committed
Changed interfaces to accommodate the distinction between soft and hard clusters
1 parent 105c888 commit dee95e5

3 files changed

Lines changed: 70 additions & 69 deletions

File tree

clusters.go

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,39 @@ import (
44
"gonum.org/v1/gonum/floats"
55
)
66

7-
type PredictFunc func(observation []float64) (*Cluster, error)
8-
97
type DistanceFunc func(a, b []float64) float64
108

11-
type Cluster struct {
12-
number int
13-
mean []float64
14-
data [][]float64
15-
}
9+
type HardCluster [][]float64
1610

17-
func (c *Cluster) Number() int {
18-
return c.number
11+
type SoftCluster struct {
12+
sizes []int
13+
data []struct {
14+
probabilities, observation []float64
15+
}
1916
}
2017

21-
func (c *Cluster) Size() int {
22-
return len(c.data)
18+
type Clusterer interface {
19+
Learn(data [][]float64) error
2320
}
2421

25-
func (c *Cluster) Data() [][]float64 {
26-
return c.data
27-
}
22+
type HardClusterer interface {
23+
Clusters() ([]HardCluster, error)
2824

29-
type Clusterer interface {
30-
Learn(data [][]float64) error
25+
Predict(observation []float64) (HardCluster, error)
26+
27+
Online(observations chan []float64, done chan bool) chan []HardCluster
28+
29+
Clusterer
30+
}
3131

32-
Predict(observation []float64) (*Cluster, error)
32+
type SoftClusterer interface {
33+
Clusters() ([]*SoftCluster, error)
3334

34-
PredictFunc() PredictFunc
35+
Predict(observation []float64) (*SoftCluster, error)
3536

36-
Compute() ([]*Cluster, error)
37+
Online(observations chan []float64, done chan bool) chan []*SoftCluster
3738

38-
Online(observations chan []float64, done chan bool) chan []*Cluster
39+
Clusterer
3940
}
4041

4142
var (

em.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package clusters
2+
3+
import (
4+
"sync"
5+
)
6+
7+
type emClusterer struct {
8+
// Training set
9+
dataset [][]float64
10+
11+
// Computed clusters. Access is synchronized to ensure no incorrect predictions are made.
12+
sync.RWMutex
13+
clusters []*SoftCluster
14+
}

kmeans.go

Lines changed: 35 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package clusters
22

33
import (
4-
"fmt"
54
"math"
65
"math/rand"
76
"sync"
@@ -10,7 +9,7 @@ import (
109
)
1110

1211
const (
13-
CHANGES_THRESHOLD = 5
12+
CHANGES_THRESHOLD = 2
1413
)
1514

1615
type kmeansClusterer struct {
@@ -23,20 +22,23 @@ type kmeansClusterer struct {
2322
distance DistanceFunc
2423

2524
// Mapping from training set points to cluster numbers.
26-
clustered map[int]int
25+
pc map[int]int
2726

2827
// Mapping from clusters' numbers to set of points they contain.
29-
points map[int][]int
28+
cp map[int][]int
29+
30+
// Mapping from clusters' numbers to their means
31+
means map[int][]float64
3032

3133
// Training set
3234
dataset [][]float64
3335

3436
// Computed clusters. Access is synchronized to ensure no incorrect predictions are made.
3537
sync.RWMutex
36-
clusters []*Cluster
38+
clusters []HardCluster
3739
}
3840

39-
func KmeansClusterer(iterations, clusters int, distance DistanceFunc) (Clusterer, error) {
41+
func KmeansClusterer(iterations, clusters int, distance DistanceFunc) (HardClusterer, error) {
4042
if iterations < 1 {
4143
return nil, ErrZeroIterations
4244
}
@@ -70,8 +72,9 @@ func (c *kmeansClusterer) Learn(data [][]float64) error {
7072

7173
c.dataset = data
7274

73-
c.clustered = make(map[int]int, len(data))
74-
c.points = make(map[int][]int, c.number)
75+
c.pc = make(map[int]int, len(data))
76+
c.cp = make(map[int][]int, c.number)
77+
c.means = make(map[int][]float64, c.number)
7578

7679
c.counter = 0
7780
c.threshold = CHANGES_THRESHOLD
@@ -80,7 +83,7 @@ func (c *kmeansClusterer) Learn(data [][]float64) error {
8083

8184
c.initializeClusters()
8285

83-
for i := 0; i < c.iterations && c.shouldStop(); i++ {
86+
for i := 0; i < c.iterations && c.notConverged(); i++ {
8487
c.run()
8588
}
8689

@@ -93,14 +96,10 @@ func (c *kmeansClusterer) Learn(data [][]float64) error {
9396
go func(n int) {
9497
defer wg.Done()
9598

96-
l := len(c.points[c.clusters[n].number])
97-
98-
c.clusters[n].data = make([][]float64, l)
99-
100-
fmt.Printf("Cluster no. %02d centroid: %v\n", c.clusters[n].number, c.clusters[n].mean)
99+
c.clusters[n] = make([][]float64, len(c.cp[n]))
101100

102-
for k := 0; k < l; k++ {
103-
c.clusters[n].data[k] = c.dataset[c.points[c.clusters[n].number][k]]
101+
for k := 0; k < len(c.cp[n]); k++ {
102+
c.clusters[n][k] = c.dataset[c.cp[n][k]]
104103
}
105104
}(j)
106105
}
@@ -109,13 +108,13 @@ func (c *kmeansClusterer) Learn(data [][]float64) error {
109108

110109
c.Unlock()
111110

112-
c.clustered = map[int]int{}
113-
c.points = map[int][]int{}
111+
c.pc = map[int]int{}
112+
c.cp = map[int][]int{}
114113

115114
return nil
116115
}
117116

118-
func (c *kmeansClusterer) Compute() ([]*Cluster, error) {
117+
func (c *kmeansClusterer) Clusters() ([]HardCluster, error) {
119118
c.RLock()
120119
defer c.RUnlock()
121120

@@ -126,12 +125,12 @@ func (c *kmeansClusterer) Compute() ([]*Cluster, error) {
126125
return c.clusters, nil
127126
}
128127

129-
func (c *kmeansClusterer) Predict(p []float64) (*Cluster, error) {
128+
func (c *kmeansClusterer) Predict(p []float64) (HardCluster, error) {
130129
c.RLock()
131130
defer c.RUnlock()
132131

133132
if c.clusters == nil {
134-
return nil, ErrEmptyClusters
133+
return HardCluster{}, ErrEmptyClusters
135134
}
136135

137136
var (
@@ -141,7 +140,7 @@ func (c *kmeansClusterer) Predict(p []float64) (*Cluster, error) {
141140
)
142141

143142
for i := 0; i < len(c.clusters); i++ {
144-
if d = c.distance(p, c.clusters[i].mean); d < m {
143+
if d = c.distance(p, c.means[i]); d < m {
145144
m = d
146145
l = i
147146
}
@@ -150,18 +149,9 @@ func (c *kmeansClusterer) Predict(p []float64) (*Cluster, error) {
150149
return c.clusters[l], nil
151150
}
152151

153-
func (c *kmeansClusterer) PredictFunc() PredictFunc {
154-
c.RLock()
155-
defer c.RUnlock()
156-
157-
return func(p []float64) (*Cluster, error) {
158-
return c.Predict(p)
159-
}
160-
}
161-
162-
func (c *kmeansClusterer) Online(observations chan []float64, done chan bool) chan []*Cluster {
152+
func (c *kmeansClusterer) Online(observations chan []float64, done chan bool) chan []HardCluster {
163153
var (
164-
r = make(chan []*Cluster)
154+
r = make(chan []HardCluster)
165155
)
166156

167157
go func() {
@@ -179,33 +169,29 @@ func (c *kmeansClusterer) Online(observations chan []float64, done chan bool) ch
179169

180170
// private
181171
func (c *kmeansClusterer) initializeClusters() {
182-
c.clusters = make([]*Cluster, c.number)
172+
c.clusters = make([]HardCluster, c.number)
183173

184174
for i := 0; i < c.number; i++ {
185-
c.clusters[i] = &Cluster{
186-
number: i,
187-
mean: c.dataset[rand.Intn(len(c.dataset)-1)],
188-
}
175+
c.means[i] = c.dataset[rand.Intn(len(c.dataset)-1)]
189176
}
190177
}
191178

192179
func (c *kmeansClusterer) run() error {
193180
for i := 0; i < len(c.clusters); i++ {
194-
var l = len(c.points[c.clusters[i].number])
195-
181+
var l = len(c.cp[i])
196182
if l == 0 {
197183
continue
198184
}
199185

200186
var m = make([]float64, len(c.dataset[0]))
201187
for j := 0; j < l; j++ {
202-
floats.Add(m, c.dataset[c.points[c.clusters[i].number][j]])
188+
floats.Add(m, c.dataset[c.cp[i][j]])
203189
}
204190

205191
floats.Scale(1/float64(l), m)
206192

207-
c.clusters[i].mean = m
208-
c.points[c.clusters[i].number] = []int{}
193+
c.means[i] = m
194+
c.cp[i] = []int{}
209195
}
210196

211197
for i := 0; i < len(c.dataset); i++ {
@@ -216,28 +202,28 @@ func (c *kmeansClusterer) run() error {
216202
)
217203

218204
for j := 0; j < len(c.clusters); j++ {
219-
if d = c.distance(c.dataset[i], c.clusters[j].mean); d < m {
205+
if d = c.distance(c.dataset[i], c.means[j]); d < m {
220206
m = d
221-
n = c.clusters[j].number
207+
n = j
222208
}
223209
}
224210

225-
if v, ok := c.clustered[i]; ok {
211+
if v, ok := c.pc[i]; ok {
226212
if v != n {
227213
c.changes++
228214
}
229215
} else {
230216
c.changes++
231217
}
232218

233-
c.clustered[i] = n
234-
c.points[n] = append(c.points[n], i)
219+
c.pc[i] = n
220+
c.cp[n] = append(c.cp[n], i)
235221
}
236222

237223
return nil
238224
}
239225

240-
func (c *kmeansClusterer) shouldStop() bool {
226+
func (c *kmeansClusterer) notConverged() bool {
241227
if c.counter == c.threshold {
242228
return false
243229
}

0 commit comments

Comments
 (0)