Skip to content

Commit 961a52a

Browse files
committed
Added tests for kmeans estimator
1 parent ce9bb69 commit 961a52a

5 files changed

Lines changed: 55 additions & 83 deletions

File tree

common_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,19 @@ func TestBounds(t *testing.T) {
101101
bounds := bounds(d)
102102

103103
if len(bounds) != 3 {
104-
t.Errorf("Mismatched bounds array length: %d vs %d", len(bounds), l)
104+
t.Errorf("Mismatched bounds array length: %d vs %d\n", len(bounds), l)
105105
}
106106

107107
if bounds[0][0] != 0.1 || bounds[0][1] != 0.7 {
108-
t.Errorf("Invalid bounds for feature #0")
108+
t.Error("Invalid bounds for feature #0")
109109
}
110110

111111
if bounds[1][0] != 0.2 || bounds[1][1] != 0.8 {
112-
t.Errorf("Invalid bounds for feature #1")
112+
t.Error("Invalid bounds for feature #1")
113113
}
114114

115115
if bounds[2][0] != 0.3 || bounds[2][1] != 0.9 {
116-
t.Errorf("Invalid bounds for feature #2")
116+
t.Error("Invalid bounds for feature #2")
117117
}
118118
}
119119

@@ -129,7 +129,7 @@ func TestUniform(t *testing.T) {
129129
for i := 0; i < l; i++ {
130130
u := uniform(d)
131131
if u < 0 || u > 10 {
132-
t.Errorf("Unformly distributed variable out of bounds")
132+
t.Error("Unformly distributed variable out of bounds")
133133
}
134134
}
135135
}

importer_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ func TestImportedLoadDataOfCorrectLengh(t *testing.T) {
1616

1717
d, e := i.Import(f, 0, 2)
1818
if e != nil {
19-
t.Errorf("Error importing data: %s", e.Error())
19+
t.Errorf("Error importing data: %s\n", e.Error())
2020
}
2121

2222
if s != len(d) {
23-
t.Errorf("Imported data size mismatch: %d vs %d", s, len(d))
23+
t.Errorf("Imported data size mismatch: %d vs %d\n", s, len(d))
2424
}
2525
}
2626

@@ -37,11 +37,11 @@ func TestImportedLoadCorrectData(t *testing.T) {
3737

3838
d, e := i.Import(f, 0, 2)
3939
if e != nil {
40-
t.Errorf("Error importing data: %s", e.Error())
40+
t.Errorf("Error importing data: %s\n", e.Error())
4141
}
4242

4343
if !fsliceEqual(d, s) {
44-
t.Error("Imported data mismatch: %v vs %v", d, s)
44+
t.Error("Imported data mismatch: %v vs %v\n", d, s)
4545
}
4646
}
4747

@@ -75,6 +75,6 @@ func BenchmarkImport(b *testing.B) {
7575

7676
_, e := i.Import(f, 4, 5)
7777
if e != nil {
78-
b.Errorf("Error importing data: %s", e.Error())
78+
b.Errorf("Error importing data: %s\n", e.Error())
7979
}
8080
}

kmeans_estimator.go

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"fmt"
55
"math"
66
"math/rand"
7-
"sync"
87
"time"
98

109
"gonum.org/v1/gonum/floats"
@@ -16,14 +15,8 @@ type kmeansEstimator struct {
1615
// variables keeping count of changes of points' membership every iteration. User as a stopping condition.
1716
changes, oldchanges, counter, threshold int
1817

19-
// For online learning only
20-
alpha float64
21-
dimension int
22-
2318
distance DistanceFunc
2419

25-
// slices holding the cluster mapping and sizes. Access is synchronized to avoid read during computation.
26-
mu sync.RWMutex
2720
a, b []int
2821

2922
// slices holding values of centroids of each clusters
@@ -75,25 +68,17 @@ func (c *kmeansEstimator) Estimate(data [][]float64) (int, error) {
7568
)
7669

7770
for i := 0; i < c.max; i++ {
78-
c.number = i
71+
c.number = i + 1
7972

8073
c.learn(data)
8174

82-
fmt.Printf("Learned data for i = %d\n", i)
83-
8475
wks[i] = math.Log(wk(c.d, c.m, c.a))
8576

86-
fmt.Printf("Computed wks for i = %d\n", i)
87-
8877
for j := 0; j < c.max; j++ {
8978
c.learn(c.buildRandomizedSet(size, bounds))
9079

91-
fmt.Printf("Learned randomized dataset for i = %d, j = %d\n", i, j)
92-
9380
bwkbs[j] = math.Log(wk(c.d, c.m, c.a))
9481
one[j] = 1
95-
96-
fmt.Printf("Computed bwkbs for i = %d, j = %d\n", i, j)
9782
}
9883

9984
wkbs[i] = floats.Sum(bwkbs) / float64(c.max)
@@ -103,16 +88,16 @@ func (c *kmeansEstimator) Estimate(data [][]float64) (int, error) {
10388
floats.Mul(bwkbs, bwkbs)
10489

10590
sk[i] = math.Sqrt(floats.Sum(bwkbs) / float64(c.max))
106-
107-
fmt.Printf("WKBS: %v\n", wkbs)
108-
fmt.Printf("SK: %v\n", sk)
10991
}
11092

11193
floats.Scale(math.Sqrt(1+(1/float64(c.max))), sk)
11294

95+
fmt.Printf("WKBS: %v\n", wkbs)
96+
fmt.Printf("SK: %v\n", sk)
97+
11398
for i := 0; i < c.max-1; i++ {
11499
if wkbs[i] >= wkbs[i+1]-sk[i+1] {
115-
estimated = i
100+
estimated = i + 1
116101
break
117102
}
118103
}
@@ -123,8 +108,6 @@ func (c *kmeansEstimator) Estimate(data [][]float64) (int, error) {
123108
// private
124109

125110
func (c *kmeansEstimator) learn(data [][]float64) {
126-
c.mu.Lock()
127-
128111
c.d = data
129112

130113
c.a = make([]int, len(data))
@@ -143,8 +126,6 @@ func (c *kmeansEstimator) learn(data [][]float64) {
143126
}
144127

145128
c.n = nil
146-
147-
c.mu.Unlock()
148129
}
149130

150131
func (c *kmeansEstimator) initializeMeansWithData() {
@@ -191,19 +172,6 @@ func (c *kmeansEstimator) initializeMeansWithData() {
191172
}
192173
}
193174

194-
func (c *kmeansEstimator) initializeMeans() {
195-
c.m = make([][]float64, c.number)
196-
197-
rand.Seed(time.Now().UTC().Unix())
198-
199-
for i := 0; i < c.number; i++ {
200-
c.m[i] = make([]float64, c.dimension)
201-
for j := 0; j < c.dimension; j++ {
202-
c.m[i][j] = 10 * (rand.Float64() - 0.5)
203-
}
204-
}
205-
}
206-
207175
func (c *kmeansEstimator) run() {
208176
var (
209177
l, k, n int = len(c.m[0]), 0, 0

kmeans_estimator_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package clusters
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestKmeansEstimator(t *testing.T) {
8+
const (
9+
C = 10
10+
E = 1
11+
)
12+
13+
var (
14+
f = "data/bus-stops.csv"
15+
i = NewCsvImporter()
16+
)
17+
18+
d, e := i.Import(f, 4, 5)
19+
if e != nil {
20+
t.Errorf("Error importing data: %s\n", e.Error())
21+
}
22+
23+
c, e := KMeansEstimator(1000, C, EuclideanDistance)
24+
if e != nil {
25+
t.Errorf("Error initializing kmeans clusterer: %s\n", e.Error())
26+
}
27+
28+
r, e := c.Estimate(d)
29+
if e != nil {
30+
t.Errorf("Error running test: %s\n", e.Error())
31+
}
32+
33+
if r != E {
34+
t.Errorf("Estimated number of clusters should be %d, it s %d\n", E, r)
35+
}
36+
}

kmeans_test.go

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,51 +16,19 @@ func TestKmeansClusterNumerMatches(t *testing.T) {
1616

1717
d, e := i.Import(f, 4, 5)
1818
if e != nil {
19-
t.Errorf("Error importing data: %s", e.Error())
19+
t.Errorf("Error importing data: %s\n", e.Error())
2020
}
2121

2222
c, e := KMeans(1000, C, EuclideanDistance)
2323
if e != nil {
24-
t.Errorf("Error initializing kmeans clusterer: %s", e.Error())
24+
t.Errorf("Error initializing kmeans clusterer: %s\n", e.Error())
2525
}
2626

2727
if e = c.Learn(d); e != nil {
28-
t.Errorf("Error learning data: %s", e.Error())
28+
t.Errorf("Error learning data: %s\n", e.Error())
2929
}
3030

3131
if len(c.Sizes()) != C {
32-
t.Errorf("Number of clusters does not match: %d vs %d", len(c.Sizes()), C)
32+
t.Errorf("Number of clusters does not match: %d vs %d\n", len(c.Sizes()), C)
3333
}
3434
}
35-
36-
/*func TestKmeansUsingGapStatistic(t *testing.T) {
37-
const (
38-
C = 8
39-
TEST = 10
40-
THRESHOLD = 2
41-
)
42-
43-
var (
44-
f = "data/bus-stops.csv"
45-
i = NewCsvImporter()
46-
)
47-
48-
d, e := i.Import(f, 4, 5)
49-
if e != nil {
50-
t.Errorf("Error importing data: %s", e.Error())
51-
}
52-
53-
c, e := KMeans(1000, C, EuclideanDistance)
54-
if e != nil {
55-
t.Errorf("Error initializing kmeans clusterer: %s", e.Error())
56-
}
57-
58-
r, e := c.Test(d, TEST)
59-
if e != nil {
60-
t.Errorf("Error running test: %s", e.Error())
61-
}
62-
63-
if math.Abs(float64(r.clusters-r.expected)) > THRESHOLD {
64-
t.Errorf("Discrepancy between numer of clusters and expectation: %d vs %d", r.clusters, r.expected)
65-
}
66-
}*/

0 commit comments

Comments
 (0)