Skip to content

Commit 7138698

Browse files
authored
feat: Add k-means clustering (algorithm-visualizer#27)
1 parent cbb3b2d commit 7138698

File tree

2 files changed

+185
-0
lines changed

2 files changed

+185
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# K-Means Clustering
2+
K-means clustering is a method to partition _n_ observations into _k_ clusters in which each observation belongs to
3+
the cluster with the nearest mean (cluster centers or cluster centroid).
4+
5+
Given a set of observations, where each observation is a d-dimensional real vector, _k-means_ clustering aims to
6+
partition the _n_ observations into _k(≤ n)_ sets so as to minimize the within-cluster sum of squares (i.e. variance).
7+
8+
## Complexity
9+
* **Time**: ![$O(n^{2k+1})$](https://latex.codecogs.com/svg.latex?O(n^{2k+1})) for 2-dimensional real vector
10+
11+
## References
12+
* [Wikipedia](https://en.wikipedia.org/wiki/K-means_clustering)
13+
* [Inspired by kmeans.js.org](https://kmeans.js.org/)
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
// import visualization libraries {
2+
const {
3+
Array2DTracer,
4+
Layout,
5+
LogTracer,
6+
Tracer,
7+
VerticalLayout,
8+
ScatterTracer,
9+
Randomize,
10+
} = require('algorithm-visualizer')
11+
// }
12+
13+
// define helper functions {
14+
const shuffle = a => {
15+
const array = a.slice(0)
16+
const copy = []
17+
let n = array.length
18+
19+
while (n) {
20+
let i = Math.floor(Math.random() * n--)
21+
copy.push(array.splice(i, 1)[0])
22+
}
23+
24+
return copy
25+
}
26+
27+
const sum = (x, y) => x + y
28+
const chooseRandomCenters = (data, k) => shuffle(data).slice(0, k)
29+
const pointify = ([x, y]) => `(${x}, ${y})`
30+
const arrayify = a => a.map(pointify)
31+
const stringify = a => arrayify(a).join(', ')
32+
const distance = ([x1, y1], [x2, y2]) => sum(Math.pow(x1 - x2, 2),
33+
Math.pow(y1 - y2, 2))
34+
const col = (a, i) => a.map(p => p[i])
35+
const mean = a => a.reduce(sum, 0) / a.length
36+
const centerOfCluster = cluster => [
37+
mean(col(cluster, 0)),
38+
mean(col(cluster, 1)),
39+
]
40+
const reCalculateCenters = clusters => clusters.map(centerOfCluster)
41+
const areCentersEqual = (c1, c2) => !!c1 && !!c2 && !(c1 < c2 || c2 < c1)
42+
43+
function cluster(data, centers) {
44+
const clusters = centers.map(() => [])
45+
46+
for (let i = 0; i < data.length; i++) {
47+
const point = data[i]
48+
let minDistance = Infinity
49+
let minDistanceIndex = -1
50+
51+
for (let j = 0; j < centers.length; j++) {
52+
const d = distance(point, centers[j])
53+
54+
if (d < minDistance) {
55+
minDistance = d
56+
minDistanceIndex = j
57+
}
58+
}
59+
60+
if (!clusters[minDistanceIndex] instanceof Array) {
61+
clusters[minDistanceIndex] = []
62+
}
63+
64+
clusters[minDistanceIndex].push(point)
65+
}
66+
67+
return clusters
68+
}
69+
70+
// }
71+
72+
// define tracer variables {
73+
const array2dTracer = new Array2DTracer('Grid')
74+
const logTracer = new LogTracer('Console')
75+
const scatterTracer = new ScatterTracer('Scatter')
76+
// }
77+
78+
// define input variables
79+
const unClusteredData = Randomize.Array2D(
80+
{ N: Randomize.Integer({ min: 10, max: 25 }) })
81+
const k = Randomize.Integer(
82+
{ min: 2, max: Math.floor(unClusteredData.length / 5) })
83+
84+
const recenterAndCluster = (originalClusters) => {
85+
const centers = reCalculateCenters(originalClusters)
86+
const clusters = cluster(unClusteredData, centers)
87+
return { centers, clusters }
88+
}
89+
90+
const improve = (loops, clusters, centers) => {
91+
const allowImprove = () => loops < 1000
92+
93+
if (!allowImprove()) {
94+
return { clusters, centers }
95+
}
96+
97+
loops++
98+
99+
const ret = recenterAndCluster(clusters)
100+
101+
// trace {
102+
array2dTracer.set(clusters.map(c => c.map(pointify)))
103+
scatterTracer.set([unClusteredData, ...ret.clusters, ret.centers])
104+
105+
logTracer.println('')
106+
logTracer.println(`Iteration #${loops} Result: `)
107+
logTracer.println(`\tClusters:`)
108+
logTracer.println(
109+
`\t\t${ret.clusters.map(c => stringify(c)).join(`\n\t\t`)}`)
110+
logTracer.println(`\tCenters:`)
111+
logTracer.println(`\t\t${stringify(ret.centers)}`)
112+
logTracer.println('')
113+
114+
Tracer.delay()
115+
// }
116+
117+
if (!allowImprove() || areCentersEqual(centers, ret.centers)) {
118+
return ret
119+
}
120+
121+
return improve(loops, ret.clusters, ret.centers)
122+
}
123+
124+
(function main() {
125+
// visualize {
126+
Layout.setRoot(new VerticalLayout([scatterTracer, array2dTracer, logTracer]))
127+
128+
logTracer.println(`Un-clustered data = ${stringify(unClusteredData)}`)
129+
array2dTracer.set([unClusteredData.map(pointify)])
130+
scatterTracer.set([unClusteredData])
131+
132+
Tracer.delay()
133+
// }
134+
135+
// Start with random centers
136+
const centers = chooseRandomCenters(unClusteredData, k)
137+
138+
// trace {
139+
logTracer.println(
140+
`Initial random selected centers = ${stringify(centers)}`)
141+
scatterTracer.set([unClusteredData, ...[[], []], centers])
142+
143+
Tracer.delay()
144+
// }
145+
146+
// Cluster to the random centers
147+
const clusters = cluster(unClusteredData, centers)
148+
149+
// trace {
150+
logTracer.println(
151+
`Initial clusters = \n\t${clusters.map(stringify).join('\n\t')}`)
152+
array2dTracer.set(clusters.map(c => c.map(pointify)))
153+
scatterTracer.set([unClusteredData, ...clusters, centers])
154+
155+
Tracer.delay()
156+
// }
157+
158+
// start iterations here
159+
const ret = improve(0, clusters, centers)
160+
161+
// trace {
162+
Tracer.delay()
163+
164+
logTracer.println(
165+
`Final clustered data = \n\t${ret.clusters.map(stringify)
166+
.join('\n\t')}`)
167+
logTracer.println(`Best centers = ${stringify(ret.centers)}`)
168+
array2dTracer.set(ret.clusters.map(c => c.map(pointify)))
169+
scatterTracer.set([unClusteredData, ...ret.clusters, ret.centers])
170+
Tracer.delay()
171+
// }
172+
})()

0 commit comments

Comments
 (0)