Skip to content

Commit a67f9eb

Browse files
committed
Started work on gap statistic test for quality of k-means++ implementation
1 parent 4c8a3aa commit a67f9eb

8 files changed

Lines changed: 200 additions & 15 deletions

File tree

clusters.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package clusters
22

33
import (
4+
"math"
5+
46
"gonum.org/v1/gonum/floats"
57
)
68

@@ -23,10 +25,16 @@ type SCEvent struct {
2325
Observation []float64
2426
}
2527

26-
/* Clusterer denotes a single operation of learning
28+
/* TestResult holds the outcome of a test run that measures the quality of a
 * clustering algorithm. */
type TestResult struct {
	// clusters and expected are compared against each other by the package
	// tests; presumably the cluster count found vs. the count expected
	// (TODO confirm once Test populates them).
	clusters int
	expected int
}
32+
33+
/* Clusterer denotes the operations of learning and testing
2734
* common for both Hard and Soft clusterers */
2835
type Clusterer interface {
29-
Learn(data [][]float64) error
36+
Learn([][]float64) error
37+
Test([][]float64, ...interface{}) (*TestResult, error)
3038
}
3139

3240
/* HardClusterer defines a set of operations for hard clustering algorithms */
@@ -80,4 +88,8 @@ var (
8088
EuclideanDistance = func(a, b []float64) float64 {
8189
return floats.Distance(a, b, 2)
8290
}
91+
92+
/* EuclideanDistanceSquared returns the L2 distance between a and b,
 * as computed by floats.Distance, raised to the power of two. */
EuclideanDistanceSquared = func(a, b []float64) float64 {
	dist := floats.Distance(a, b, 2)

	return math.Pow(dist, 2)
}
8395
)

common.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ package clusters
22

33
import (
44
"container/heap"
5+
"math/rand"
6+
"sync"
7+
8+
"gonum.org/v1/gonum/floats"
59
)
610

711
// struct denoting start and end indices of database portion to be scanned for nearest neighbours by workers in DBSCAN and OPTICS
@@ -63,3 +67,56 @@ func (pq *priorityQueue) Update(item *pItem, value int, priority float64) {
6367
item.p = priority
6468
heap.Fix(pq, item.i)
6569
}
70+
71+
func wk(data [][]float64, centroids [][]float64, mapping []int) float64 {
72+
var (
73+
l = float64(2 * len(data[0]))
74+
wk = make([]float64, len(centroids))
75+
)
76+
77+
for i := 0; i < len(mapping); i++ {
78+
wk[mapping[i]-1] += EuclideanDistanceSquared(centroids[mapping[i]-1], data[i]) / l
79+
}
80+
81+
return floats.Sum(wk)
82+
}
83+
84+
// bounds returns, for every dimension of data, a pointer to a
// [min, max] pair holding the smallest and largest value observed in that
// dimension. The number of dimensions is taken from the first observation.
//
// Each dimension is scanned by its own goroutine; every goroutine writes only
// to its own pair, and the function returns after all of them have finished.
//
// An empty (or nil) data set yields a nil slice instead of panicking.
func bounds(data [][]float64) []*[2]float64 {
	// Guard before touching data[0].
	if len(data) == 0 {
		return nil
	}

	var (
		wg sync.WaitGroup

		dims   = len(data[0])
		limits = make([]*[2]float64, dims)
	)

	// Seed every pair with the first observation so min and max start equal.
	for i := range limits {
		limits[i] = &[2]float64{data[0][i], data[0][i]}
	}

	wg.Add(dims)

	for i := 0; i < dims; i++ {
		go func(n int) {
			defer wg.Done()

			pair := limits[n]

			for _, row := range data {
				// min <= max always holds, so a value can update at most one side.
				switch v := row[n]; {
				case v < pair[0]:
					pair[0] = v
				case v > pair[1]:
					pair[1] = v
				}
			}
		}(i)
	}

	wg.Wait()

	return limits
}
119+
120+
// uniform draws a pseudo-random value from the interval that starts at
// data[0] and has width data[1]-data[0], using the math/rand global source.
func uniform(data *[2]float64) float64 {
	lo, hi := data[0], data[1]
	width := hi - lo

	return rand.Float64()*width + lo
}

dbscan.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,14 @@ func (c *dbscanClusterer) Learn(data [][]float64) error {
102102
return nil
103103
}
104104

105+
func (c *dbscanClusterer) Test(data [][]float64, args ...interface{}) (*TestResult, error) {
106+
if len(data) == 0 {
107+
return nil, ErrEmptySet
108+
}
109+
110+
return nil, nil
111+
}
112+
105113
func (c *dbscanClusterer) Sizes() []int {
106114
c.mu.RLock()
107115
defer c.mu.RUnlock()

importer.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,18 @@ import (
88
"strconv"
99
)
1010

11-
type Importer struct {
11+
type Importer interface {
12+
Import(file string, start, end int) ([][]float64, error)
1213
}
1314

14-
func NewImporter() *Importer {
15-
return &Importer{}
15+
type csvImporter struct {
1616
}
1717

18-
func (i *Importer) Import(file string, start, end, size int) ([][]float64, error) {
18+
func NewCsvImporter() Importer {
19+
return &csvImporter{}
20+
}
21+
22+
func (i *csvImporter) Import(file string, start, end int) ([][]float64, error) {
1923
if start < 0 || end < 0 || start > end {
2024
return [][]float64{}, ErrInvalidRange
2125
}
@@ -28,7 +32,7 @@ func (i *Importer) Import(file string, start, end, size int) ([][]float64, error
2832
defer f.Close()
2933

3034
var (
31-
d = make([][]float64, 0, size)
35+
d = make([][]float64, 0)
3236
r = csv.NewReader(bufio.NewReader(f))
3337
s = end - start + 1
3438
g []float64

importer_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ const TOLERANCE = 0.000001
1010
func TestImportedLoadDataOfCorrectLengh(t *testing.T) {
1111
var (
1212
f = "data/test.csv"
13-
i = NewImporter()
13+
i = NewCsvImporter()
1414
s = 3
1515
)
1616

17-
d, e := i.Import(f, 0, 2, 3)
17+
d, e := i.Import(f, 0, 2)
1818
if e != nil {
1919
t.Errorf("Error importing data: %s", e.Error())
2020
}
@@ -27,15 +27,15 @@ func TestImportedLoadDataOfCorrectLengh(t *testing.T) {
2727
func TestImportedLoadCorrectData(t *testing.T) {
2828
var (
2929
f = "data/test.csv"
30-
i = NewImporter()
30+
i = NewCsvImporter()
3131
s = [][]float64{
3232
[]float64{0.1, 0.2, 0.3},
3333
[]float64{0.4, 0.5, 0.6},
3434
[]float64{0.7, 0.8, 0.9},
3535
}
3636
)
3737

38-
d, e := i.Import(f, 0, 2, 3)
38+
d, e := i.Import(f, 0, 2)
3939
if e != nil {
4040
t.Errorf("Error importing data: %s", e.Error())
4141
}
@@ -68,12 +68,12 @@ func fsliceEqual(a, b [][]float64) bool {
6868
func BenchmarkImport(b *testing.B) {
6969
var (
7070
f = "data/bus-stops.csv"
71-
i = NewImporter()
71+
i = NewCsvImporter()
7272
)
7373

7474
b.ResetTimer()
7575

76-
_, e := i.Import(f, 4, 5, 15000)
76+
_, e := i.Import(f, 4, 5)
7777
if e != nil {
7878
b.Errorf("Error importing data: %s", e.Error())
7979
}

kmeans.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package clusters
22

33
import (
4+
"errors"
45
"math"
56
"math/rand"
67
"sync"
@@ -108,6 +109,52 @@ func (c *kmeansClusterer) Learn(data [][]float64) error {
108109
return nil
109110
}
110111

112+
func (c *kmeansClusterer) Test(data [][]float64, args ...interface{}) (*TestResult, error) {
113+
if len(data) == 0 {
114+
return nil, ErrEmptySet
115+
}
116+
117+
clusters, ok := args[0].(int)
118+
if !ok {
119+
return nil, errors.New("Argument #0 is invalid")
120+
}
121+
122+
var (
123+
size = len(data)
124+
bounds = bounds(data)
125+
wks = make([]float64, clusters)
126+
wkbs = make([]float64, clusters)
127+
sk = make([]float64, clusters)
128+
one = make([]float64, clusters)
129+
bwkbs = make([]float64, clusters)
130+
)
131+
132+
for i := 0; i < clusters; i++ {
133+
c.Learn(data)
134+
135+
wks[i] = math.Log(wk(c.d, c.m, c.a))
136+
137+
for j := 0; j < clusters; j++ {
138+
c.Learn(c.buildRandomizedSet(size, bounds))
139+
140+
bwkbs[j] = math.Log(wk(c.d, c.m, c.a))
141+
one[j] = 1
142+
}
143+
144+
wkbs[i] = floats.Sum(bwkbs) / float64(clusters)
145+
146+
floats.Scale(wkbs[i], one)
147+
floats.Sub(bwkbs, one)
148+
floats.Mul(bwkbs, bwkbs)
149+
150+
sk[i] = math.Sqrt(floats.Sum(bwkbs) / float64(clusters))
151+
}
152+
153+
floats.Scale(math.Sqrt(1+(1/float64(clusters))), sk)
154+
155+
return nil, nil
156+
}
157+
111158
func (c *kmeansClusterer) Sizes() []int {
112159
c.mu.RLock()
113160
defer c.mu.RUnlock()
@@ -324,3 +371,20 @@ func (c *kmeansClusterer) check() {
324371

325372
c.oldchanges = c.changes
326373
}
374+
375+
func (c *kmeansClusterer) buildRandomizedSet(size int, bounds []*[2]float64) [][]float64 {
376+
var (
377+
l = len(bounds)
378+
r = make([][]float64, size)
379+
)
380+
381+
for i := 0; i < size; i++ {
382+
r[i] = make([]float64, l)
383+
384+
for j := 0; j < l; j++ {
385+
r[i][j] = uniform(bounds[j])
386+
}
387+
}
388+
389+
return r
390+
}

kmeans_test.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package clusters
22

33
import (
4+
"math"
45
"testing"
56
)
67

@@ -11,10 +12,10 @@ func TestKmeansClusterNumerMatches(t *testing.T) {
1112

1213
var (
1314
f = "data/bus-stops.csv"
14-
i = NewImporter()
15+
i = NewCsvImporter()
1516
)
1617

17-
d, e := i.Import(f, 4, 5, 15000)
18+
d, e := i.Import(f, 4, 5)
1819
if e != nil {
1920
t.Errorf("Error importing data: %s", e.Error())
2021
}
@@ -32,3 +33,34 @@ func TestKmeansClusterNumerMatches(t *testing.T) {
3233
t.Errorf("Number of clusters does not match: %d vs %d", len(c.Sizes()), C)
3334
}
3435
}
36+
37+
func TestKmeansUsingGapStatistic(t *testing.T) {
38+
const (
39+
C = 8
40+
THRESHOLD = 2
41+
)
42+
43+
var (
44+
f = "data/bus-stops.csv"
45+
i = NewCsvImporter()
46+
)
47+
48+
d, e := i.Import(f, 4, 5)
49+
if e != nil {
50+
t.Errorf("Error importing data: %s", e.Error())
51+
}
52+
53+
c, e := KMeans(1000, C, EuclideanDistance)
54+
if e != nil {
55+
t.Errorf("Error initializing kmeans clusterer: %s", e.Error())
56+
}
57+
58+
r, e := c.Test(d, 20)
59+
if e != nil {
60+
t.Errorf("Error running test: %s", e.Error())
61+
}
62+
63+
if math.Abs(float64(r.clusters-r.expected)) > THRESHOLD {
64+
t.Errorf("Discrepancy between numer of clusters and expectation: %d vs %d", r.clusters, r.expected)
65+
}
66+
}

optics.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,14 @@ func (c *opticsClusterer) Learn(data [][]float64) error {
135135
return nil
136136
}
137137

138+
func (c *opticsClusterer) Test(data [][]float64, args ...interface{}) (*TestResult, error) {
139+
if len(data) == 0 {
140+
return nil, ErrEmptySet
141+
}
142+
143+
return nil, nil
144+
}
145+
138146
func (c *opticsClusterer) Sizes() []int {
139147
c.mu.RLock()
140148
defer c.mu.RUnlock()

0 commit comments

Comments
 (0)