Skip to content

Commit 6411416

Browse files
committed
Moved function wk to kmeansEstimator
1 parent 961a52a commit 6411416

5 files changed

Lines changed: 35 additions & 37 deletions

File tree

clusters.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,15 @@ var (
9393
}
9494

9595
EuclideanDistanceSquared = func(a, b []float64) float64 {
96-
t := floats.Distance(a, b, 2)
97-
return t * t
96+
var (
97+
s, t float64
98+
)
99+
100+
for i, _ := range a {
101+
t = a[i] - b[i]
102+
s += t * t
103+
}
104+
105+
return s
98106
}
99107
)

common.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import (
44
"container/heap"
55
"math/rand"
66
"sync"
7-
8-
"gonum.org/v1/gonum/floats"
97
)
108

119
// struct denoting start and end indices of database portion to be scanned for nearest neighbours by workers in DBSCAN and OPTICS
@@ -68,19 +66,6 @@ func (pq *priorityQueue) Update(item *pItem, value int, priority float64) {
6866
heap.Fix(pq, item.i)
6967
}
7068

71-
func wk(data [][]float64, centroids [][]float64, mapping []int) float64 {
72-
var (
73-
l = float64(2 * len(data[0]))
74-
wk = make([]float64, len(centroids))
75-
)
76-
77-
for i := 0; i < len(mapping); i++ {
78-
wk[mapping[i]-1] += EuclideanDistanceSquared(centroids[mapping[i]-1], data[i]) / l
79-
}
80-
81-
return floats.Sum(wk)
82-
}
83-
8469
func bounds(data [][]float64) []*[2]float64 {
8570
var (
8671
wg sync.WaitGroup

common_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ func TestBounds(t *testing.T) {
9595

9696
d, e := i.Import(f, 0, 2)
9797
if e != nil {
98-
t.Errorf("Error importing data: %s", e.Error())
98+
t.Errorf("Error importing data: %s\n", e.Error())
9999
}
100100

101101
bounds := bounds(d)

errors.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@ package clusters
33
import "errors"
44

55
var (
6-
ErrEmptySet = errors.New("Empty training set")
7-
ErrNotTrained = errors.New("You need to train the algorithm first")
8-
ErrZeroIterations = errors.New("Number of iterations cannot be less than 1")
9-
ErrOneCluster = errors.New("Number of clusters cannot be less than 2")
10-
ErrZeroEpsilon = errors.New("Epsilon cannot be 0")
11-
ErrZeroMinpts = errors.New("MinPts cannot be 0")
12-
ErrZeroWorkers = errors.New("Number of workers cannot be less than 0")
13-
ErrZeroXi = errors.New("Xi cannot be 0")
14-
ErrInvalidRange = errors.New("Range is invalid")
15-
ErrTestingNotSupported = errors.New("Testing is not supported for this algorithm")
6+
ErrEmptySet = errors.New("Empty training set")
7+
ErrNotTrained = errors.New("You need to train the algorithm first")
8+
ErrZeroIterations = errors.New("Number of iterations cannot be less than 1")
9+
ErrOneCluster = errors.New("Number of clusters cannot be less than 2")
10+
ErrZeroEpsilon = errors.New("Epsilon cannot be 0")
11+
ErrZeroMinpts = errors.New("MinPts cannot be 0")
12+
ErrZeroWorkers = errors.New("Number of workers cannot be less than 0")
13+
ErrZeroXi = errors.New("Xi cannot be 0")
14+
ErrInvalidRange = errors.New("Range is invalid")
1615
)

kmeans_estimator.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package clusters
22

33
import (
4-
"fmt"
54
"math"
65
"math/rand"
76
"time"
@@ -72,12 +71,12 @@ func (c *kmeansEstimator) Estimate(data [][]float64) (int, error) {
7271

7372
c.learn(data)
7473

75-
wks[i] = math.Log(wk(c.d, c.m, c.a))
74+
wks[i] = math.Log(c.wk(c.d, c.m, c.a))
7675

7776
for j := 0; j < c.max; j++ {
7877
c.learn(c.buildRandomizedSet(size, bounds))
7978

80-
bwkbs[j] = math.Log(wk(c.d, c.m, c.a))
79+
bwkbs[j] = math.Log(c.wk(c.d, c.m, c.a))
8180
one[j] = 1
8281
}
8382

@@ -92,9 +91,6 @@ func (c *kmeansEstimator) Estimate(data [][]float64) (int, error) {
9291

9392
floats.Scale(math.Sqrt(1+(1/float64(c.max))), sk)
9493

95-
fmt.Printf("WKBS: %v\n", wkbs)
96-
fmt.Printf("SK: %v\n", sk)
97-
9894
for i := 0; i < c.max-1; i++ {
9995
if wkbs[i] >= wkbs[i+1]-sk[i+1] {
10096
estimated = i + 1
@@ -106,7 +102,6 @@ func (c *kmeansEstimator) Estimate(data [][]float64) (int, error) {
106102
}
107103

108104
// private
109-
110105
func (c *kmeansEstimator) learn(data [][]float64) {
111106
c.d = data
112107

@@ -124,8 +119,6 @@ func (c *kmeansEstimator) learn(data [][]float64) {
124119
c.run()
125120
c.check()
126121
}
127-
128-
c.n = nil
129122
}
130123

131124
func (c *kmeansEstimator) initializeMeansWithData() {
@@ -223,6 +216,19 @@ func (c *kmeansEstimator) check() {
223216
c.oldchanges = c.changes
224217
}
225218

219+
func (c *kmeansEstimator) wk(data [][]float64, centroids [][]float64, mapping []int) float64 {
220+
var (
221+
l = float64(2 * len(data[0]))
222+
wk = make([]float64, len(centroids))
223+
)
224+
225+
for i := 0; i < len(mapping); i++ {
226+
wk[mapping[i]-1] += EuclideanDistanceSquared(centroids[mapping[i]-1], data[i]) / l
227+
}
228+
229+
return floats.Sum(wk)
230+
}
231+
226232
func (c *kmeansEstimator) buildRandomizedSet(size int, bounds []*[2]float64) [][]float64 {
227233
var (
228234
l = len(bounds)

0 commit comments

Comments
 (0)