90 lines
2.6 KiB
Python
90 lines
2.6 KiB
Python
import random
|
|
import math
|
|
from PIL import Image
|
|
import sys
|
|
from itertools import chain
|
|
from pprint import pprint
|
|
|
|
def random_k_points(k, points):
    """Return up to ``k`` distinct points drawn at random from ``points``.

    ``points`` is copied first, so the caller's collection is never mutated.
    Fewer than ``k`` points are returned when ``points`` holds fewer than
    ``k`` distinct values; an empty list is returned when ``k <= 0`` or
    ``points`` is empty.
    """
    pool = list(points)
    chosen = []
    while len(chosen) < k and pool:
        # Swap the randomly drawn candidate to the end so it can be removed
        # in O(1).  The candidate is removed whether or not it is accepted:
        # leaving rejected duplicates in the pool would let the loop spin
        # forever once only duplicates remain.
        ix = random.randrange(len(pool))
        pool[ix], pool[-1] = pool[-1], pool[ix]
        candidate = pool.pop()
        if candidate not in chosen:
            chosen.append(candidate)
    return chosen
|
|
|
|
def distance(point, centroid):
    """Return the Euclidean distance between two 3-component points.

    Both ``point`` and ``centroid`` must unpack into exactly three values.
    """
    px, py, pz = point
    cx, cy, cz = centroid
    squared_length = (cx - px) ** 2 + (cy - py) ** 2 + (cz - pz) ** 2
    return math.sqrt(squared_length)
|
|
|
|
def argmin(l):
    """Return the index of the smallest element of ``l``.

    When several elements tie for the minimum, the index of the first one is
    returned.  Indexing ``l[0]`` raises ``IndexError`` when ``l`` is empty.
    """
    best_ix = 0
    best_value = l[0]
    for ix, value in enumerate(l):
        # Strict comparison keeps the earliest index on ties.
        if value < best_value:
            best_ix = ix
            best_value = value
    return best_ix
|
|
|
|
def calculate_centroid(points):
    """Return the component-wise mean of a collection of 3-component points.

    The result is a 3-tuple.  An empty ``points`` raises
    ``ZeroDivisionError`` because the sums are divided by ``len(points)``.
    """
    totals = [0, 0, 0]
    for x, y, z in points:
        totals[0] += x
        totals[1] += y
        totals[2] += z
    count = len(points)
    return tuple(total / count for total in totals)
|
|
|
|
def k_means_cluster(k, points):
    """Group ``points`` into ``k`` clusters using iterative k-means.

    The initial centroids are ``k`` distinct points drawn at random from
    ``points``.  Each iteration assigns every point to its nearest centroid
    and then moves each centroid to the mean of its assigned points, until
    the centroids stop changing.

    ``points`` is iterated once per iteration, so it must be re-iterable
    (a list or tuple, not a one-shot iterator).

    Returns a list of ``k`` lists; the i-th list holds the points assigned
    to the i-th centroid.  A returned cluster may be empty.

    Raises ``ValueError`` when ``k`` distinct starting points cannot be
    drawn from ``points``.
    """
    # Initialization: choose k centroids (Forgy, Random Partition, etc.)
    centroids = random_k_points(k, points)
    if len(centroids) != k:
        raise ValueError(
            "cannot pick %d distinct starting centroids; only %d available"
            % (k, len(centroids))
        )

    # Loop until convergence
    while True:
        # Start every iteration from empty clusters
        clusters = [[] for _ in range(k)]

        # Assign each point to the "closest" centroid
        for point in points:
            distances_to_each_centroid = [
                distance(point, centroid) for centroid in centroids
            ]
            clusters[argmin(distances_to_each_centroid)].append(point)

        # Calculate new centroids as the mean of each cluster's points.
        # A cluster that received no points keeps its previous centroid;
        # averaging an empty cluster would divide by zero.
        new_centroids = [
            calculate_centroid(cluster) if cluster else centroid
            for cluster, centroid in zip(clusters, centroids)
        ]

        if new_centroids == centroids:
            return clusters
        centroids = new_centroids
|
|
|
|
def palettize(centroids, point):
    """Return the index in ``centroids`` of the centroid nearest to ``point``."""
    return argmin([distance(point, candidate) for candidate in centroids])
|
|
|
|
def main():
    """Convert an image to a palette image with at most 16 colours.

    Command line: ``INPUT_IMAGE OUTPUT_IMAGE``.  The input is read as RGB,
    its pixels are clustered with ``k_means_cluster``, the cluster means
    become the palette, and every pixel is replaced by the index of its
    nearest palette entry.  The palette is printed to stdout.
    """
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} INPUT_IMAGE OUTPUT_IMAGE")
    source, output = sys.argv[1], sys.argv[2]

    # Only the pixel data and the dimensions are needed, so the input file
    # is released before the clustering starts.
    with Image.open(source) as im:
        size = im.size
        pixels = list(im.convert("RGB").getdata())

    # k-means needs k distinct starting points, so an image with fewer than
    # 16 distinct colours gets a correspondingly smaller palette.
    k = min(16, len(set(pixels)))
    clusters = k_means_cluster(k, pixels)

    # An empty cluster has no mean (it would divide by zero), so skip it.
    palette = [calculate_centroid(cluster) for cluster in clusters if cluster]
    print(palette)

    palimage = Image.new('P', size)
    # putpalette takes a flat [r, g, b, r, g, b, ...] list of integers.
    palimage.putpalette([int(channel) for channel in chain.from_iterable(palette)])
    palimage.putdata([palettize(palette, pixel) for pixel in pixels])
    palimage.save(output)


if __name__ == "__main__":
    main()
|