-
Notifications
You must be signed in to change notification settings - Fork 0
/
lloyd.go
73 lines (62 loc) · 1.82 KB
/
lloyd.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
package kmeaaaaans
import (
"runtime"
"sync"
"github.com/panjf2000/ants/v2"
"gonum.org/v1/gonum/mat"
)
// lloydKmeans holds the configuration used by (*lloydKmeans).Fit to train
// k-means centroids.
type lloydKmeans struct {
// nClusters is the number of centroids; Fit allocates a centroid matrix
// with this many rows and passes it to calcInitialCentroids.
nClusters uint
// tolerance is the convergence threshold: Fit keeps iterating only while
// calcError between the previous and next centroids is greater than it.
tolerance float64
// maxIterations is the upper bound on the number of update iterations
// Fit performs.
maxIterations uint
// chunkSize is handed to makeChunks to split the sample indices into the
// units of work that Fit submits to the worker pool.
chunkSize uint
// initAlgorithm is handed to calcInitialCentroids to select how the
// starting centroids are chosen.
initAlgorithm InitAlgorithm
}
// Compile-time assertion that *lloydKmeans is assignable to Kmeans
// (presumably an interface declared elsewhere in this package); it has no
// runtime effect.
var _ Kmeans = (*lloydKmeans)(nil)
// updateLloydCentroids finalizes nextCentroids in place, one row per entry of
// nSamplesInCluster.
//
// For a cluster whose count is non-zero, its row in nextCentroids is
// multiplied by 1/count (presumably turning a per-cluster sum written by the
// caller into a mean — confirm against accumulateSamples). For a cluster with
// a zero count, the row is overwritten with the corresponding row of
// centroids, so an empty cluster keeps its previous position.
//
// centroids is only read; nextCentroids is modified.
func updateLloydCentroids(centroids, nextCentroids *mat.Dense, nSamplesInCluster []uint) {
	for cluster, count := range nSamplesInCluster {
		if count == 0 {
			// Nothing was assigned to this cluster: carry the old centroid over.
			nextCentroids.SetRow(cluster, centroids.RawRowView(cluster))
			continue
		}

		inv := 1.0 / float64(count)
		// RawRowView aliases the matrix storage, so scaling the slice updates
		// nextCentroids directly.
		row := nextCentroids.RawRowView(cluster)
		for j := range row {
			row[j] *= inv
		}
	}
}
// Fit trains k-means centroids on X, whose rows are samples and whose columns
// are features, and returns them wrapped in a trainedKmeans.
//
// Starting centroids come from calcInitialCentroids. Each iteration then:
//  1. swaps the centroid buffers so that centroids holds the latest estimate,
//  2. runs assignCluster for every chunk of sample indices on a worker pool
//     sized to runtime.NumCPU() and waits for all chunks to finish,
//  3. calls accumulateSamples and updateLloydCentroids to build the next
//     estimate in nextCentroids.
//
// Iteration stops after k.maxIterations rounds or as soon as calcError between
// the two centroid buffers is no longer greater than k.tolerance.
//
// A non-nil error is returned, together with a nil model, when the worker pool
// cannot be created or when it rejects a task.
func (k *lloydKmeans) Fit(X *mat.Dense) (TrainedKmeans, error) {
	nSamples, featDim := X.Dims()

	nextCentroids := calcInitialCentroids(X, k.nClusters, k.initAlgorithm)
	centroids := mat.NewDense(int(k.nClusters), featDim, nil)

	// Only the pool created here is released. The package-level ants.Release()
	// is deliberately not called: it shuts down the process-wide default pool,
	// which this function never uses and other code may depend on.
	pool, err := ants.NewPool(runtime.NumCPU())
	if err != nil {
		// Surface the failure instead of returning an empty model as success.
		return nil, err
	}
	defer pool.Release()

	classes := make([]uint, nSamples)
	indices := makeSequence(uint(nSamples))
	chunks := makeChunks(indices, k.chunkSize)
	nSamplesInCluster := make([]uint, k.nClusters)

	for i := 0; i < int(k.maxIterations) && k.tolerance < calcError(centroids, nextCentroids); i++ {
		// After the swap, centroids is the current estimate and nextCentroids
		// is the buffer that will receive the updated one.
		centroids, nextCentroids = nextCentroids, centroids

		var (
			wg        sync.WaitGroup
			submitErr error
		)
		for _, chunk := range chunks {
			chunk := chunk // per-iteration copy for the closure (pre-Go 1.22 semantics)
			wg.Add(1)
			err := pool.Submit(func() {
				defer wg.Done()
				assignCluster(X, centroids, classes, chunk, calcL2Distance)
			})
			if err != nil {
				// The task was rejected and will never run, so balance the
				// Add above; otherwise wg.Wait would block forever.
				wg.Done()
				submitErr = err
				break
			}
		}
		// Always wait, even on failure, so no worker is still touching
		// classes or centroids when we return.
		wg.Wait()
		if submitErr != nil {
			return nil, submitErr
		}

		accumulateSamples(X, nextCentroids, nSamplesInCluster, classes, indices)
		updateLloydCentroids(centroids, nextCentroids, nSamplesInCluster)
	}

	// nextCentroids always holds the most recent estimate, including the case
	// where the loop body never ran.
	return &trainedKmeans{
		centroids: nextCentroids,
	}, nil
}