Skip to content

Commit d84f880

Browse files
committed
initial commit
1 parent 7037ffb commit d84f880

3 files changed

Lines changed: 308 additions & 0 deletions

File tree

clusters.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package clusters
2+
3+
import (
4+
"gonum.org/v1/gonum/floats"
5+
)
6+
7+
type PredictFunc func(observation []float64) (*Cluster, error)
8+
9+
type DistanceFunc func(a, b []float64) float64
10+
11+
type Cluster struct {
12+
number int
13+
mean []float64
14+
data [][]float64
15+
}
16+
17+
func (c *Cluster) Number() int {
18+
return c.number
19+
}
20+
21+
func (c *Cluster) Size() int {
22+
return len(c.data)
23+
}
24+
25+
func (c *Cluster) Data() [][]float64 {
26+
return c.data
27+
}
28+
29+
type Clusterer interface {
30+
Learn(data [][]float64) error
31+
32+
Predict(observation []float64) (*Cluster, error)
33+
34+
PredictFunc() PredictFunc
35+
36+
Compute() ([]*Cluster, error)
37+
38+
Online(observations chan []float64, done chan bool) chan []*Cluster
39+
}
40+
41+
var (
42+
EuclideanDistance = func(a, b []float64) float64 {
43+
return floats.Distance(a, b, 2)
44+
}
45+
)

errors.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package clusters
2+
3+
import "errors"
4+
5+
var (
6+
ErrDimensionMismatch = errors.New("Vectors have different dimension")
7+
ErrEmptySet = errors.New("Empty training set")
8+
ErrEmptyClusters = errors.New("Empty clusters")
9+
ErrZeroIterations = errors.New("Number of iterations cannot be less than 1")
10+
ErrZeroClusters = errors.New("Number of clusters cannot be less than 1")
11+
)

kmeans.go

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
package clusters
2+
3+
import (
4+
"fmt"
5+
"math"
6+
"math/rand"
7+
"sync"
8+
9+
"gonum.org/v1/gonum/floats"
10+
)
11+
12+
const (
13+
CHANGES_THRESHOLD = 5
14+
)
15+
16+
type kmeansClusterer struct {
17+
iterations int
18+
number int
19+
20+
// Variables keeping count of changes of points' membership every iteration. User as a stopping condition.
21+
changes, oldchanges, counter, threshold int
22+
23+
distance DistanceFunc
24+
25+
// Mapping from training set points to cluster numbers.
26+
clustered map[int]int
27+
28+
// Mapping from clusters' numbers to set of points they contain.
29+
points map[int][]int
30+
31+
// Training set
32+
dataset [][]float64
33+
34+
// Computed clusters. Access is synchronized to accertain no incorrect predictions are made.
35+
sync.RWMutex
36+
clusters []*Cluster
37+
}
38+
39+
func KmeansClusterer(iterations, clusters int, distance DistanceFunc) (Clusterer, error) {
40+
if iterations < 1 {
41+
return nil, ErrZeroIterations
42+
}
43+
44+
if clusters < 1 {
45+
return nil, ErrZeroClusters
46+
}
47+
48+
var d DistanceFunc
49+
{
50+
if distance != nil {
51+
d = distance
52+
} else {
53+
d = EuclideanDistance
54+
}
55+
}
56+
57+
return &kmeansClusterer{
58+
iterations: iterations,
59+
number: clusters,
60+
distance: d,
61+
}, nil
62+
}
63+
64+
func (c *kmeansClusterer) Learn(data [][]float64) error {
65+
if len(data) == 0 {
66+
return ErrEmptySet
67+
}
68+
69+
c.Lock()
70+
71+
c.dataset = data
72+
73+
c.clustered = make(map[int]int, len(data))
74+
c.points = make(map[int][]int, c.number)
75+
76+
c.counter = 0
77+
c.threshold = CHANGES_THRESHOLD
78+
c.changes = 0
79+
c.oldchanges = 0
80+
81+
c.initializeClusters()
82+
83+
for i := 0; i < c.iterations && c.shouldStop(); i++ {
84+
c.run()
85+
}
86+
87+
var wg sync.WaitGroup
88+
{
89+
wg.Add(c.number)
90+
}
91+
92+
for j := 0; j < c.number; j++ {
93+
go func(n int) {
94+
defer wg.Done()
95+
96+
l := len(c.points[c.clusters[n].number])
97+
98+
c.clusters[n].data = make([][]float64, 0, l)
99+
100+
fmt.Printf("Cluster no. %02d centroid: %v\n", c.clusters[n].number, c.clusters[n].mean)
101+
102+
for k := 0; k < l; k++ {
103+
c.clusters[n].data = append(c.clusters[n].data, c.dataset[c.points[c.clusters[n].number][k]])
104+
}
105+
}(j)
106+
}
107+
108+
wg.Wait()
109+
110+
c.Unlock()
111+
112+
c.clustered = map[int]int{}
113+
c.points = map[int][]int{}
114+
115+
return nil
116+
}
117+
118+
func (c *kmeansClusterer) Compute() ([]*Cluster, error) {
119+
c.RLock()
120+
defer c.RUnlock()
121+
122+
if c.clusters == nil {
123+
return nil, ErrEmptyClusters
124+
}
125+
126+
return c.clusters, nil
127+
}
128+
129+
func (c *kmeansClusterer) Predict(p []float64) (*Cluster, error) {
130+
c.RLock()
131+
defer c.RUnlock()
132+
133+
if c.clusters == nil {
134+
return nil, ErrEmptyClusters
135+
}
136+
137+
var (
138+
l int
139+
d float64
140+
m float64 = math.MaxFloat64
141+
)
142+
143+
for i := 0; i < len(c.clusters); i++ {
144+
if d = c.distance(p, c.clusters[i].mean); d < m {
145+
m = d
146+
l = i
147+
}
148+
}
149+
150+
return c.clusters[l], nil
151+
}
152+
153+
func (c *kmeansClusterer) PredictFunc() PredictFunc {
154+
c.RLock()
155+
defer c.RUnlock()
156+
157+
return func(p []float64) (*Cluster, error) {
158+
return c.Predict(p)
159+
}
160+
}
161+
162+
func (c *kmeansClusterer) Online(observations chan []float64, done chan bool) chan []*Cluster {
163+
var (
164+
r = make(chan []*Cluster)
165+
)
166+
167+
go func() {
168+
for {
169+
select {
170+
case <-observations:
171+
case <-done:
172+
return
173+
}
174+
}
175+
}()
176+
177+
return r
178+
}
179+
180+
// private
181+
func (c *kmeansClusterer) initializeClusters() {
182+
c.clusters = make([]*Cluster, 0, c.number)
183+
184+
for i := 0; i < c.number; i++ {
185+
c.clusters = append(c.clusters, &Cluster{
186+
number: i,
187+
mean: c.dataset[rand.Intn(len(c.dataset)-1)],
188+
})
189+
}
190+
}
191+
192+
func (c *kmeansClusterer) run() error {
193+
for i := 0; i < len(c.clusters); i++ {
194+
var l = len(c.points[c.clusters[i].number])
195+
196+
if l == 0 {
197+
continue
198+
}
199+
200+
var m = make([]float64, len(c.dataset[0]))
201+
for j := 0; j < l; j++ {
202+
floats.Add(m, c.dataset[c.points[c.clusters[i].number][j]])
203+
}
204+
205+
floats.Scale(1/float64(l), m)
206+
207+
c.clusters[i].mean = m
208+
c.points[c.clusters[i].number] = []int{}
209+
}
210+
211+
for i := 0; i < len(c.dataset); i++ {
212+
var (
213+
n int
214+
d float64
215+
m float64 = math.MaxFloat64
216+
)
217+
218+
for j := 0; j < len(c.clusters); j++ {
219+
if d = c.distance(c.dataset[i], c.clusters[j].mean); d < m {
220+
m = d
221+
n = c.clusters[j].number
222+
}
223+
}
224+
225+
if v, ok := c.clustered[i]; ok {
226+
if v != n {
227+
c.changes++
228+
}
229+
} else {
230+
c.changes++
231+
}
232+
233+
c.clustered[i] = n
234+
c.points[n] = append(c.points[n], i)
235+
}
236+
237+
return nil
238+
}
239+
240+
func (c *kmeansClusterer) shouldStop() bool {
241+
if c.counter == c.threshold {
242+
return false
243+
}
244+
245+
if c.changes == c.oldchanges {
246+
c.counter++
247+
}
248+
249+
c.oldchanges = c.changes
250+
251+
return true
252+
}

0 commit comments

Comments
 (0)