-
Notifications
You must be signed in to change notification settings - Fork 0
/
minibatch.go
107 lines (93 loc) · 2.86 KB
/
minibatch.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
package kmeaaaaans
import (
"math"
"math/rand"
"runtime"
"sync"
"github.com/panjf2000/ants/v2"
"gonum.org/v1/gonum/mat"
)
// miniBatchKmeans clusters data with the mini-batch k-means algorithm:
// each iteration updates the centroids from a random batch of samples
// rather than the full data set.
type miniBatchKmeans struct {
	nClusters     uint    // number of clusters to fit
	tolerance     float64 // training stops once centroid movement falls below this
	maxIterations uint    // hard cap on training iterations
	// maxNoImprobe: stop after this many consecutive iterations without an
	// inertia improvement. NOTE(review): likely a typo for "maxNoImprove" —
	// renaming would touch code outside this file, so it is only flagged here.
	maxNoImprobe  uint
	batchSize     uint          // number of samples drawn per iteration
	initAlgorithm InitAlgorithm // strategy used to choose the initial centroids
}

// Compile-time check that *miniBatchKmeans satisfies the Kmeans interface.
var _ Kmeans = (*miniBatchKmeans)(nil)
// updateMiniBatchCentroids folds one mini-batch into the running centroid
// estimates. On entry, nextCentroids holds the per-cluster sums accumulated
// from the current batch and centroids holds the previous centroid positions.
// Each cluster that received samples gets a convex combination of its old
// position and the batch contribution, weighted by the lifetime sample count
// (accNSamplesInCluster); a cluster that received no samples keeps its
// previous position unchanged.
func updateMiniBatchCentroids(nextCentroids *mat.Dense, centroids *mat.Dense, nSamplesInCluster []uint, accNSamplesInCluster []uint) {
	cols := nextCentroids.RawMatrix().Cols
	for c, batchCount := range nSamplesInCluster {
		accNSamplesInCluster[c] += batchCount
		if batchCount == 0 {
			// No samples landed in this cluster: carry the centroid forward.
			nextCentroids.SetRow(c, centroids.RawRowView(c))
			continue
		}
		invTotal := 1.0 / float64(accNSamplesInCluster[c])
		batchWeight := invTotal * float64(batchCount)
		next := nextCentroids.RawRowView(c)
		cur := centroids.RawRowView(c)
		for d := 0; d < cols; d++ {
			next[d] = invTotal*next[d] + (1-batchWeight)*cur[d]
		}
	}
}
// Fit runs mini-batch k-means on X and returns the trained centroids.
//
// Each iteration draws a batch of samples from a shuffled index permutation
// (reshuffling at the start of every full pass), assigns the batch to the
// nearest centroids in parallel on a worker pool, and folds the batch
// statistics into the centroid estimates. Training stops when centroid
// movement drops below k.tolerance, after k.maxIterations iterations, or
// after more than k.maxNoImprobe iterations without an inertia improvement.
func (k *miniBatchKmeans) Fit(X *mat.Dense) (TrainedKmeans, error) {
	nSamples, featDim := X.Dims()
	nextCentroids := calcInitialCentroids(X, k.nClusters, k.initAlgorithm)
	centroids := mat.NewDense(int(k.nClusters), featDim, nil)
	// Releases the ants default pool; only the local pool below is actually
	// used, so this is defensive cleanup.
	defer ants.Release()
	pool, err := ants.NewPool(runtime.NumCPU())
	if err != nil {
		// BUG FIX: the original returned a nil error here, silently handing
		// the caller an empty, untrained model. Propagate the failure.
		return &trainedKmeans{}, err
	}
	defer pool.Release()
	// Clamp the batch size so the index arithmetic below never divides by
	// zero (uint(nSamples)/batchSize) or slices past allIndices when the
	// configured batch is larger than the data set.
	batchSize := k.batchSize
	if batchSize == 0 || batchSize > uint(nSamples) {
		batchSize = uint(nSamples)
	}
	classes := make([]uint, X.RawMatrix().Rows)
	accNSamplesInCluster := make([]uint, k.nClusters)
	nSamplesInCluster := make([]uint, k.nClusters)
	chunkSize := (batchSize + uint(runtime.NumCPU()) - 1) / uint(runtime.NumCPU())
	minInertia := math.MaxFloat64
	minRuns := uint(0)
	allIndices := makeSequence(uint(nSamples))
	for i := 0; i < int(k.maxIterations) && k.tolerance < calcError(centroids, nextCentroids); i++ {
		centroids, nextCentroids = nextCentroids, centroids
		// Walk the shuffled permutation batch by batch; reshuffle whenever a
		// full pass over the data begins.
		maxIndex := uint(nSamples) / batchSize
		beg := (uint(i) % maxIndex) * batchSize
		end := beg + batchSize
		if beg == 0 {
			rand.Shuffle(len(allIndices), func(i, j int) { allIndices[i], allIndices[j] = allIndices[j], allIndices[i] })
		}
		indices := allIndices[beg:end]
		chunks := makeChunks(indices, chunkSize)
		// Assign the batch to clusters in parallel, summing the inertia
		// under a mutex.
		inertia := 0.0
		var wg sync.WaitGroup
		var mu sync.Mutex
		for _, chunk := range chunks {
			chunk := chunk
			task := func() {
				defer wg.Done()
				partialInertia := assignCluster(X, centroids, classes, chunk, calcL2Distance)
				mu.Lock()
				defer mu.Unlock()
				inertia += partialInertia
			}
			wg.Add(1)
			if err := pool.Submit(task); err != nil {
				// BUG FIX: the original ignored Submit's error; a rejected
				// task never called wg.Done and wg.Wait deadlocked. Fall
				// back to running the task inline.
				task()
			}
		}
		wg.Wait()
		// Early-stopping bookkeeping on the batch inertia.
		if inertia < minInertia {
			minInertia = inertia
			minRuns = 0
		} else {
			minRuns++
		}
		if k.maxNoImprobe < minRuns {
			break
		}
		accumulateSamples(X, nextCentroids, nSamplesInCluster, classes, indices)
		updateMiniBatchCentroids(nextCentroids, centroids, nSamplesInCluster, accNSamplesInCluster)
	}
	centroids = nextCentroids
	return &trainedKmeans{
		centroids: centroids,
	}, nil
}