K-means Algoritm (K평균 군집화 알고리즘)
K-means (MacQueen, 1967) 은 유명한 군집화 (Clustering) 문제를 해결하는 가장 간단한 자율학습 (Unsupervised Learning) 알고리즘중 하나이다. 사전에 정해진 어떤수의 클러스터를 통해서 주어진 데이터 집합을 분류하는 간단하고 쉬운 방법.
k-means 는 partitional clustering 에 속한다.
data 이외에 cluster 의 수 k를 input 으로 하며 이때 k를 seed point 라고 한다. seed point 는 임의로 선택되며 바람직한 cluster 구조에 관한 어떤 지식들이 seed point를 선택하는데에 사용될 수 있다. Forgy' algorithm 과 다른점은 하나의 sample 이 하나의 cluster 에 합류하자마자 곧 cluster 의 centroid 가 다시 계산된다는 것이다. 또한 Forgy' algorithm 이 반복적(iterative) 한 반면에 k-means algorithm 은 data set에서 단지 두 번만의 pass 가 이루어진다. 그 과정은 다음과 같다.
1. 처음에 k cluster 로서 시작한다. 남아있는 n-k sample들에 대해서는 가장 가까이 있는 centroid를 찾는다. 이것에 가장 가까이 있는 centroid를 가지는 것이 확인된 cluster 에 sample을 포함시킨다. 각각의 sample 들이 할당된 후에 할당된 cluster 의 centroid 가 다시 계산된다.
2. 그 data를 두 번 처리한다. 각 sample에 대하여 가장 가까이 있는 centroid를 찾는다. 가장 가까이 있는 centroid를 가진 것으로 확인된 cluster 에 sample을 위치시킨다. (이 step 에서는 어떤 centroid 도 다시 계산하지 않는다.
(reference : AIstudy - http://www.aistudy.com)
위 설명을 바탕으로 한번 구현해 보았다.
실제 데이터들을 바탕으로 써 먹을수 있게끔 구현하였고,
visualization은 알아서 하면 될 듯.
< 500개의 데이터를 k-means 알고리즘으로 군집화(30개의 클래스) >
weight.java
public class weight {
public double [] value;
public int num;
public weight(int length, boolean rnd){
value = new double[length];
if(rnd)
for(int i = 0; i < length; i++)
value[i] = Math.random();
num = -1; // non-clustering
}
public void setNumber(int num){
this.num = num;
}
public int getNumber(){
return num;
}
public double getLength(){
return value.length;
}
public void set(int index, double val){
value[index] = val;
}
public double get(int index){
return value[index];
}
public double distance(weight w){
return Math.sqrt(distanceSq(w));
}
public double distanceSq(weight w){
if(w.getLength() != value.length)
return -1; // error
else{
double distSq = 0;
double d;
for(int i = 0; i < value.length; i++){
d = value[i] - w.get(i);
distSq += d * d;
}
return distSq;
}
}
}
KCluster.java
public class KCluster extends weight{
public int num;
public KCluster(int length, boolean rnd, int num){
super(length, rnd);
this.num = num;
}
public int getNumber(){
return num;
}
public void setWeight(weight w){
value = w.value;
}
}
ClusteringEngine.java
import java.util.ArrayList;
public class ClusteringEngine implements Runnable{
volatile Thread timer;
public ArrayList<weight> dataSet;
public ArrayList<KCluster> kSet;
public int length;
public double threshold = 0.005;
public double err=0;
@Override
public void run() {
// TODO Auto-generated method stub
while(timer == Thread.currentThread()){
try{
Thread.sleep(1000);
}catch(InterruptedException e){ }
// running Method
clustering();
if(!rePosition())
stop();
}
}
public void start(){
timer = new Thread(this);
timer.start();
}
public void stop(){
timer = null;
}
public ClusteringEngine(int length,int dataSize, int kSize){
dataSet = new ArrayList<weight>();
kSet = new ArrayList<KCluster>();
this.length = length;
for(int i = 0; i < dataSize; i++) // initializing dataSet
dataSet.add(new weight(length,true));
for(int i = 0; i < kSize; i++) // initializing KClusterSet
kSet.add(new KCluster(length, true, i));
}
public void clustering(){
for(weight w : dataSet)
w.setNumber(getBestClass(w));
}
public int getBestClass(weight w){ // 웨이트와 가장 가까운 k를 찾아 거기에 해당하는 넘버를 리턴
KCluster min = kSet.get((int)(Math.random() * kSet.size()));
for(KCluster k : kSet)
if(!min.equals(k) && min.distance(w) > k.distance(w))
min = k;
return min.getNumber();
}
public boolean rePosition(){
double avgDist = 0;
for(int i = 0; i < kSet.size(); i++){
weight avgWeight = averaging(i);
KCluster k = kSet.get(i);
avgDist += avgWeight.distance(k);
k.setWeight(avgWeight);
}
avgDist /= kSet.size();
err = avgDist;
if(avgDist > threshold)
return true; // 재배치를 하였으면 true
else
return false; // 아니면 false
}
public weight averaging(int num){ // 클래스 넘버에 해당하는 데이터만 찾아서 평균위치를 찾음.
weight avg = new weight(length, false);
int count = 0;
for(weight w : dataSet)
if(num == w.getNumber()){
for(int i = 0; i < length; i++)
avg.set(i, avg.get(i) + w.get(i));
count++;
}
for(int i = 0; i < length; i++)
avg.set(i, avg.get(i) / count);
return avg;
}
public ArrayList<weight> getDataArray(){
return dataSet;
}
public ArrayList<KCluster> getKArray(){
return kSet;
}
}
'Codes' 카테고리의 다른 글
Flocking in Java - testFrame (0) | 2011.11.01 |
---|---|
Java - 최적화된 회전 방향 결정 (0) | 2011.09.01 |