11package clusters
22
33import (
4- "fmt"
54 "math"
65 "math/rand"
76 "sync"
@@ -10,7 +9,7 @@ import (
109)
1110
1211const (
13- CHANGES_THRESHOLD = 5
12+ CHANGES_THRESHOLD = 2
1413)
1514
1615type kmeansClusterer struct {
@@ -23,20 +22,23 @@ type kmeansClusterer struct {
2322 distance DistanceFunc
2423
2524 // Mapping from training set points to cluster numbers.
26- clustered map [int ]int
25+ pc map [int ]int
2726
2827 // Mapping from clusters' numbers to set of points they contain.
29- points map [int ][]int
28+ cp map [int ][]int
29+
30+ // Mapping from clusters' numbers to their means
31+ means map [int ][]float64
3032
3133 // Training set
3234 dataset [][]float64
3335
3436 // Computed clusters. Access is synchronized to ensure no incorrect predictions are made.
3537 sync.RWMutex
36- clusters []* Cluster
38+ clusters []HardCluster
3739}
3840
39- func KmeansClusterer (iterations , clusters int , distance DistanceFunc ) (Clusterer , error ) {
41+ func KmeansClusterer (iterations , clusters int , distance DistanceFunc ) (HardClusterer , error ) {
4042 if iterations < 1 {
4143 return nil , ErrZeroIterations
4244 }
@@ -70,8 +72,9 @@ func (c *kmeansClusterer) Learn(data [][]float64) error {
7072
7173 c .dataset = data
7274
73- c .clustered = make (map [int ]int , len (data ))
74- c .points = make (map [int ][]int , c .number )
75+ c .pc = make (map [int ]int , len (data ))
76+ c .cp = make (map [int ][]int , c .number )
77+ c .means = make (map [int ][]float64 , c .number )
7578
7679 c .counter = 0
7780 c .threshold = CHANGES_THRESHOLD
@@ -80,7 +83,7 @@ func (c *kmeansClusterer) Learn(data [][]float64) error {
8083
8184 c .initializeClusters ()
8285
83- for i := 0 ; i < c .iterations && c .shouldStop (); i ++ {
86+ for i := 0 ; i < c .iterations && c .notConverged (); i ++ {
8487 c .run ()
8588 }
8689
@@ -93,14 +96,10 @@ func (c *kmeansClusterer) Learn(data [][]float64) error {
9396 go func (n int ) {
9497 defer wg .Done ()
9598
96- l := len (c .points [c .clusters [n ].number ])
97-
98- c .clusters [n ].data = make ([][]float64 , l )
99-
100- fmt .Printf ("Cluster no. %02d centroid: %v\n " , c .clusters [n ].number , c .clusters [n ].mean )
99+ c .clusters [n ] = make ([][]float64 , len (c .cp [n ]))
101100
102- for k := 0 ; k < l ; k ++ {
103- c .clusters [n ]. data [k ] = c .dataset [c .points [ c . clusters [ n ]. number ][k ]]
101+ for k := 0 ; k < len ( c . cp [ n ]) ; k ++ {
102+ c.clusters [n ][k ] = c .dataset [c.cp [ n ][k ]]
104103 }
105104 }(j )
106105 }
@@ -109,13 +108,13 @@ func (c *kmeansClusterer) Learn(data [][]float64) error {
109108
110109 c .Unlock ()
111110
112- c .clustered = map [int ]int {}
113- c .points = map [int ][]int {}
111+ c .pc = map [int ]int {}
112+ c .cp = map [int ][]int {}
114113
115114 return nil
116115}
117116
118- func (c * kmeansClusterer ) Compute () ([]* Cluster , error ) {
117+ func (c * kmeansClusterer ) Clusters () ([]HardCluster , error ) {
119118 c .RLock ()
120119 defer c .RUnlock ()
121120
@@ -126,12 +125,12 @@ func (c *kmeansClusterer) Compute() ([]*Cluster, error) {
126125 return c .clusters , nil
127126}
128127
129- func (c * kmeansClusterer ) Predict (p []float64 ) (* Cluster , error ) {
128+ func (c * kmeansClusterer ) Predict (p []float64 ) (HardCluster , error ) {
130129 c .RLock ()
131130 defer c .RUnlock ()
132131
133132 if c .clusters == nil {
134- return nil , ErrEmptyClusters
133+ return HardCluster {} , ErrEmptyClusters
135134 }
136135
137136 var (
@@ -141,7 +140,7 @@ func (c *kmeansClusterer) Predict(p []float64) (*Cluster, error) {
141140 )
142141
143142 for i := 0 ; i < len (c .clusters ); i ++ {
144- if d = c .distance (p , c .clusters [i ]. mean ); d < m {
143+ if d = c .distance (p , c .means [i ]); d < m {
145144 m = d
146145 l = i
147146 }
@@ -150,18 +149,9 @@ func (c *kmeansClusterer) Predict(p []float64) (*Cluster, error) {
150149 return c .clusters [l ], nil
151150}
152151
153- func (c * kmeansClusterer ) PredictFunc () PredictFunc {
154- c .RLock ()
155- defer c .RUnlock ()
156-
157- return func (p []float64 ) (* Cluster , error ) {
158- return c .Predict (p )
159- }
160- }
161-
162- func (c * kmeansClusterer ) Online (observations chan []float64 , done chan bool ) chan []* Cluster {
152+ func (c * kmeansClusterer ) Online (observations chan []float64 , done chan bool ) chan []HardCluster {
163153 var (
164- r = make (chan []* Cluster )
154+ r = make (chan []HardCluster )
165155 )
166156
167157 go func () {
@@ -179,33 +169,29 @@ func (c *kmeansClusterer) Online(observations chan []float64, done chan bool) ch
179169
180170// private
181171func (c * kmeansClusterer ) initializeClusters () {
182- c .clusters = make ([]* Cluster , c .number )
172+ c .clusters = make ([]HardCluster , c .number )
183173
184174 for i := 0 ; i < c .number ; i ++ {
185- c .clusters [i ] = & Cluster {
186- number : i ,
187- mean : c .dataset [rand .Intn (len (c .dataset )- 1 )],
188- }
175+ c .means [i ] = c .dataset [rand .Intn (len (c .dataset )- 1 )]
189176 }
190177}
191178
192179func (c * kmeansClusterer ) run () error {
193180 for i := 0 ; i < len (c .clusters ); i ++ {
194- var l = len (c .points [c .clusters [i ].number ])
195-
181+ var l = len (c .cp [i ])
196182 if l == 0 {
197183 continue
198184 }
199185
200186 var m = make ([]float64 , len (c .dataset [0 ]))
201187 for j := 0 ; j < l ; j ++ {
202- floats .Add (m , c .dataset [c .points [ c . clusters [ i ]. number ][j ]])
188+ floats .Add (m , c .dataset [c.cp [ i ][j ]])
203189 }
204190
205191 floats .Scale (1 / float64 (l ), m )
206192
207- c .clusters [i ]. mean = m
208- c .points [ c . clusters [ i ]. number ] = []int {}
193+ c .means [i ] = m
194+ c .cp [ i ] = []int {}
209195 }
210196
211197 for i := 0 ; i < len (c .dataset ); i ++ {
@@ -216,28 +202,28 @@ func (c *kmeansClusterer) run() error {
216202 )
217203
218204 for j := 0 ; j < len (c .clusters ); j ++ {
219- if d = c .distance (c .dataset [i ], c .clusters [j ]. mean ); d < m {
205+ if d = c .distance (c .dataset [i ], c .means [j ]); d < m {
220206 m = d
221- n = c . clusters [ j ]. number
207+ n = j
222208 }
223209 }
224210
225- if v , ok := c .clustered [i ]; ok {
211+ if v , ok := c .pc [i ]; ok {
226212 if v != n {
227213 c .changes ++
228214 }
229215 } else {
230216 c .changes ++
231217 }
232218
233- c .clustered [i ] = n
234- c .points [n ] = append (c .points [n ], i )
219+ c .pc [i ] = n
220+ c .cp [n ] = append (c .cp [n ], i )
235221 }
236222
237223 return nil
238224}
239225
240- func (c * kmeansClusterer ) shouldStop () bool {
226+ func (c * kmeansClusterer ) notConverged () bool {
241227 if c .counter == c .threshold {
242228 return false
243229 }
0 commit comments