2024-09-13 05:39:54 -05:00

90 lines
2.6 KiB
Python

import random
import math
from PIL import Image
import sys
from itertools import chain
from pprint import pprint
def random_k_points(k, points):
points = list(points)
out = []
while len(out) < k and len(points) != 0:
ix = random.randint(0, len(points) - 1)
if points[ix] not in out:
value = points.pop(ix)
out.append(value)
return out
def distance(point, centroid):
x0, y0, z0 = point
x1, y1, z1 = centroid
xd = (x1 - x0) ** 2
yd = (y1 - y0) ** 2
zd = (z1 - z0) ** 2
return math.sqrt(xd + yd + zd)
def argmin(l):
min = (0, l[0])
for ix, v in enumerate(l[1:]):
if v < min[1]:
min = (ix+1, v)
return min[0]
def calculate_centroid(points):
xs = 0
ys = 0
zs = 0
for x, y, z in points:
xs += x
ys += y
zs += z
return (xs / len(points), ys / len(points), zs / len(points))
def k_means_cluster(k, points):
# Initialization: choose k centroids (Forgy, Random Partition, etc.)
centroids = random_k_points(k, points)
assert len(centroids) == k
# Initialize clusters list
clusters = [[] for _ in range(k)]
# Loop until convergence
converged = False
while not converged:
# Clear previous clusters
clusters = [[] for _ in range(k)]
# Assign each point to the "closest" centroid
for point in points:
distances_to_each_centroid = [distance(point, centroid) for centroid in centroids]
cluster_assignment = argmin(distances_to_each_centroid)
clusters[cluster_assignment].append(point)
# Calculate new centroids
# (the standard implementation uses the mean of all points in a
# cluster to determine the new centroid)
print(clusters)
new_centroids = [calculate_centroid(cluster) for cluster in clusters]
converged = (new_centroids == centroids)
centroids = new_centroids
if converged:
return clusters
def palettize(centroids, point):
distances_to_each_centroid = [distance(point, centroid) for centroid in centroids]
cluster_assignment = argmin(distances_to_each_centroid)
return cluster_assignment
output = sys.argv[2]
with Image.open(sys.argv[1]) as im:
pixels = list(im.convert("RGB").getdata())
clusters = k_means_cluster(16, pixels)
palette = list(calculate_centroid(cluster) for cluster in clusters)
print(palette)
palimage = Image.new('P', im.size)
palimage.putpalette(map(int, list(chain.from_iterable(palette))))
palimage.putdata([palettize(palette, pixel) for pixel in pixels])
palimage.save(sys.argv[2])