gen: add k_means
This commit is contained in:
parent
32b8f85a4f
commit
5caa763578
146
gen/k_means/k_means_cluster.cpp
Normal file
146
gen/k_means/k_means_cluster.cpp
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
#include <cstdint>
|
||||||
|
#include <cstring>
|
||||||
|
#include <cassert>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
uint32_t xorshift32(uint32_t * state)
|
||||||
|
{
|
||||||
|
uint32_t x = *state;
|
||||||
|
x ^= x << 13;
|
||||||
|
x ^= x >> 17;
|
||||||
|
x ^= x << 5;
|
||||||
|
return *state = x;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
return: point indicies in `ix`
|
||||||
|
*/
|
||||||
|
void random_k_points(uint32_t * random_state,
|
||||||
|
int k,
|
||||||
|
int num_points,
|
||||||
|
int * point_indices)
|
||||||
|
{
|
||||||
|
int indices[num_points];
|
||||||
|
for (int i = 0; i < num_points; i++) {
|
||||||
|
indices[i] = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < k; i++) {
|
||||||
|
int ix = xorshift32(random_state) % num_points;
|
||||||
|
num_points -= 1;
|
||||||
|
int point_ix = indices[ix];
|
||||||
|
*point_indices++ = point_ix;
|
||||||
|
memmove(&indices[ix], &indices[ix + 1], (num_points - ix) * (sizeof (indices[0])));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int dimension>
|
||||||
|
double distance_squared(double a[dimension], double b[dimension])
|
||||||
|
{
|
||||||
|
double sum = 0;
|
||||||
|
for (int i = 0; i < dimension; i++) {
|
||||||
|
double c = a[i] - b[i];
|
||||||
|
sum += c * c;
|
||||||
|
}
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int dimension>
|
||||||
|
void calculate_centroid(double points[][dimension], int cluster[], int length, double out[dimension])
|
||||||
|
{
|
||||||
|
for (int i = 0; i < dimension; i++) {
|
||||||
|
out[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int d = 0; d < dimension; d++) {
|
||||||
|
for (int i = 0; i < length; i++) {
|
||||||
|
out[d] += points[cluster[i]][d];
|
||||||
|
}
|
||||||
|
out[d] /= (double)length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int dimension>
|
||||||
|
int minimum_distance_centroid(double point[dimension],
|
||||||
|
double centroids[][dimension],
|
||||||
|
int k)
|
||||||
|
{
|
||||||
|
double min_distance = distance_squared<dimension>(point, centroids[0]);
|
||||||
|
int min_ix = 0;
|
||||||
|
for (int centroid_ix = 1; centroid_ix < k; centroid_ix++) {
|
||||||
|
double distance = distance_squared<dimension>(point, centroids[centroid_ix]);
|
||||||
|
if (distance < min_distance) {
|
||||||
|
min_distance = distance;
|
||||||
|
min_ix = centroid_ix;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return min_ix;
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr double epsilon = 0.000001;
|
||||||
|
|
||||||
|
template <int dimension>
|
||||||
|
bool point_equal(double a[dimension], double b[dimension])
|
||||||
|
{
|
||||||
|
for (int i = 0; i < dimension; i++) {
|
||||||
|
if (a[i] - b[i] > epsilon)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int dimension>
|
||||||
|
void set_vector(double dst[dimension], double src[dimension])
|
||||||
|
{
|
||||||
|
for (int i = 0; i < dimension; i++) {
|
||||||
|
dst[i] = src[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int dimension>
|
||||||
|
void k_means_cluster(uint32_t * random_state,
|
||||||
|
int k,
|
||||||
|
double points[][dimension],
|
||||||
|
int length,
|
||||||
|
double out[][dimension])
|
||||||
|
{
|
||||||
|
int centroid_indices[k];
|
||||||
|
random_k_points(random_state, k, length, centroid_indices);
|
||||||
|
double centroids[k][dimension];
|
||||||
|
for (int i = 0; i < k; i++) {
|
||||||
|
set_vector<dimension>(centroids[i], /* = */ points[centroid_indices[i]]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 268.4 MB stack usage at 1024×1024 px
|
||||||
|
int clusters[k][length];
|
||||||
|
int cluster_lengths[length];
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
// clear cluster lengths
|
||||||
|
for (int i = 0; i < length; i++) { cluster_lengths[i] = 0; }
|
||||||
|
|
||||||
|
// assign each point to the closest centroid
|
||||||
|
for (int point_ix = 0; point_ix < length; point_ix++) {
|
||||||
|
int min_cluster_ix = minimum_distance_centroid<dimension>(points[point_ix], centroids, k);
|
||||||
|
clusters[min_cluster_ix][cluster_lengths[min_cluster_ix]] = point_ix;
|
||||||
|
cluster_lengths[min_cluster_ix]++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculate new centroids
|
||||||
|
bool converged = true;
|
||||||
|
for (int cluster_ix = 0; cluster_ix < k; cluster_ix++) {
|
||||||
|
double new_centroid[dimension];
|
||||||
|
calculate_centroid<dimension>(points, clusters[cluster_ix], cluster_lengths[cluster_ix], new_centroid);
|
||||||
|
converged &= point_equal<dimension>(new_centroid, centroids[cluster_ix]);
|
||||||
|
set_vector<dimension>(centroids[cluster_ix], /* = */ new_centroid);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (converged)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// return centroids
|
||||||
|
for (int centroid_ix = 0; centroid_ix < k; centroid_ix++) {
|
||||||
|
set_vector<dimension>(out[centroid_ix], /* = */ centroids[centroid_ix]);
|
||||||
|
}
|
||||||
|
}
|
245
gen/k_means/k_means_vq.cpp
Normal file
245
gen/k_means/k_means_vq.cpp
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
#include <cstdio>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cassert>
|
||||||
|
#include <ctime>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cinttypes>
|
||||||
|
#include <cerrno>
|
||||||
|
#include <bit>
|
||||||
|
|
||||||
|
#include "k_means_cluster.cpp"
|
||||||
|
#include "ppm.h"
|
||||||
|
|
||||||
|
#include "twiddle.hpp"
|
||||||
|
#include "color_format.hpp"
|
||||||
|
|
||||||
|
void rgb_to_vectors(const uint8_t * rgb, int width, int height, double vectors[][12])
|
||||||
|
{
|
||||||
|
for (int ty = 0; ty < height / 2; ty++) {
|
||||||
|
for (int tx = 0; tx < width / 2; tx++) {
|
||||||
|
int ai = ((ty * 2) + 0) * width + ((tx * 2) + 0);
|
||||||
|
int bi = ((ty * 2) + 1) * width + ((tx * 2) + 0);
|
||||||
|
int ci = ((ty * 2) + 0) * width + ((tx * 2) + 1);
|
||||||
|
int di = ((ty * 2) + 1) * width + ((tx * 2) + 1);
|
||||||
|
|
||||||
|
vectors[ty * width / 2 + tx][0] = static_cast<double>(rgb[ai * 3 + 0]);
|
||||||
|
vectors[ty * width / 2 + tx][1] = static_cast<double>(rgb[ai * 3 + 1]);
|
||||||
|
vectors[ty * width / 2 + tx][2] = static_cast<double>(rgb[ai * 3 + 2]);
|
||||||
|
|
||||||
|
vectors[ty * width / 2 + tx][3] = static_cast<double>(rgb[bi * 3 + 0]);
|
||||||
|
vectors[ty * width / 2 + tx][4] = static_cast<double>(rgb[bi * 3 + 1]);
|
||||||
|
vectors[ty * width / 2 + tx][5] = static_cast<double>(rgb[bi * 3 + 2]);
|
||||||
|
|
||||||
|
vectors[ty * width / 2 + tx][6] = static_cast<double>(rgb[ci * 3 + 0]);
|
||||||
|
vectors[ty * width / 2 + tx][7] = static_cast<double>(rgb[ci * 3 + 1]);
|
||||||
|
vectors[ty * width / 2 + tx][8] = static_cast<double>(rgb[ci * 3 + 2]);
|
||||||
|
|
||||||
|
vectors[ty * width / 2 + tx][9] = static_cast<double>(rgb[di * 3 + 0]);
|
||||||
|
vectors[ty * width / 2 + tx][10] = static_cast<double>(rgb[di * 3 + 1]);
|
||||||
|
vectors[ty * width / 2 + tx][11] = static_cast<double>(rgb[di * 3 + 2]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void codepixels_to_rgb(double codebook[][12], uint8_t codepixels[], int width, int height, uint8_t * rgb)
|
||||||
|
{
|
||||||
|
for (int ty = 0; ty < height / 2; ty++) {
|
||||||
|
for (int tx = 0; tx < width / 2; tx++) {
|
||||||
|
int codepixel = codepixels[ty * width / 2 + tx];
|
||||||
|
double (&vector)[12] = codebook[codepixel];
|
||||||
|
int ai = ((ty * 2) + 0) * width + ((tx * 2) + 0);
|
||||||
|
int bi = ((ty * 2) + 1) * width + ((tx * 2) + 0);
|
||||||
|
int ci = ((ty * 2) + 0) * width + ((tx * 2) + 1);
|
||||||
|
int di = ((ty * 2) + 1) * width + ((tx * 2) + 1);
|
||||||
|
rgb[ai * 3 + 0] = static_cast<uint8_t>(round(vector[0]));
|
||||||
|
rgb[ai * 3 + 1] = static_cast<uint8_t>(round(vector[1]));
|
||||||
|
rgb[ai * 3 + 2] = static_cast<uint8_t>(round(vector[2]));
|
||||||
|
|
||||||
|
rgb[bi * 3 + 0] = static_cast<uint8_t>(round(vector[3]));
|
||||||
|
rgb[bi * 3 + 1] = static_cast<uint8_t>(round(vector[4]));
|
||||||
|
rgb[bi * 3 + 2] = static_cast<uint8_t>(round(vector[5]));
|
||||||
|
|
||||||
|
rgb[ci * 3 + 0] = static_cast<uint8_t>(round(vector[6]));
|
||||||
|
rgb[ci * 3 + 1] = static_cast<uint8_t>(round(vector[7]));
|
||||||
|
rgb[ci * 3 + 2] = static_cast<uint8_t>(round(vector[8]));
|
||||||
|
|
||||||
|
rgb[di * 3 + 0] = static_cast<uint8_t>(round(vector[9]));
|
||||||
|
rgb[di * 3 + 1] = static_cast<uint8_t>(round(vector[10]));
|
||||||
|
rgb[di * 3 + 2] = static_cast<uint8_t>(round(vector[11]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void palettize_vectors_to_codebook(double codebook[256][12], double vectors[][12], int vectors_length, uint8_t codepixels[])
|
||||||
|
{
|
||||||
|
for (int vector_ix = 0; vector_ix < vectors_length; vector_ix++) {
|
||||||
|
int min_cluster_ix = minimum_distance_centroid(vectors[vector_ix], codebook, 256);
|
||||||
|
assert(min_cluster_ix <= 255 && min_cluster_ix >= 0);
|
||||||
|
codepixels[vector_ix] = min_cluster_ix;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
double total_rgb_error(uint8_t const * const a, uint8_t const * const b, int length)
|
||||||
|
{
|
||||||
|
double error = 0;
|
||||||
|
for (int i = 0; i < length; i++) {
|
||||||
|
double d = ((double)a[i]) - ((double)b[i]);
|
||||||
|
d = d * d;
|
||||||
|
error += d;
|
||||||
|
}
|
||||||
|
return error;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool endswith(const char * s, const char * tail)
|
||||||
|
{
|
||||||
|
int s_len = strlen(s);
|
||||||
|
int tail_len = strlen(tail);
|
||||||
|
|
||||||
|
if (s_len < tail_len) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int start = s_len - tail_len;
|
||||||
|
for (int i = start; i < s_len; i++) {
|
||||||
|
printf("%c %c\n", s[i], tail[i - start]);
|
||||||
|
if (s[i] != tail[i - start])
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<class T, std::endian target_endian = std::endian::little>
|
||||||
|
constexpr T byteswap(const T n)
|
||||||
|
{
|
||||||
|
if (std::endian::native != target_endian) {
|
||||||
|
return std::byteswap<T>(n);
|
||||||
|
} else {
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t color_convert(double vector[12])
|
||||||
|
{
|
||||||
|
uint64_t texel[4];
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
double r = round(vector[i * 3 + 0]);
|
||||||
|
double g = round(vector[i * 3 + 1]);
|
||||||
|
double b = round(vector[i * 3 + 2]);
|
||||||
|
double a = 255;
|
||||||
|
if (r > 255 || g > 255 || b > 255)
|
||||||
|
fprintf(stderr, "%.0f %.0f %.0f\n", r, g ,b);
|
||||||
|
if (r > 255) r = 255;
|
||||||
|
if (g > 255) g = 255;
|
||||||
|
if (b > 255) b = 255;
|
||||||
|
if (r < 0) r = 0;
|
||||||
|
if (g < 0) g = 0;
|
||||||
|
if (b < 0) b = 0;
|
||||||
|
//assert(r <= 255 && g <= 255 && b <= 255);
|
||||||
|
//assert(r >= 0 && g >= 0 && b >= 0);
|
||||||
|
texel[i] = color_format::rgb565(a, r, g, b);
|
||||||
|
}
|
||||||
|
return (texel[3] << 48) | (texel[2] << 32) | (texel[1] << 16) | (texel[0] << 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char * argv[])
|
||||||
|
{
|
||||||
|
if (argc < 3) {
|
||||||
|
printf("argc < 3\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
FILE *f = fopen(argv[1], "rb");
|
||||||
|
if (f == nullptr) {
|
||||||
|
printf("%s: %s\n", argv[1], strerror(errno));
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
fseek(f, 0, SEEK_END);
|
||||||
|
long size = ftell(f);
|
||||||
|
fseek(f, 0, SEEK_SET);
|
||||||
|
|
||||||
|
uint8_t buf[size + 1];
|
||||||
|
ssize_t read_len = fread(buf, size, 1, f);
|
||||||
|
assert(read_len == 1);
|
||||||
|
fclose(f);
|
||||||
|
buf[size] = 0;
|
||||||
|
|
||||||
|
struct ppm_header ppm;
|
||||||
|
int success = ppm_parse(buf, size, &ppm);
|
||||||
|
if (success < 0) {
|
||||||
|
fprintf(stderr, "ppm parse failed\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
assert(ppm.length == ppm.width * ppm.height * 3);
|
||||||
|
|
||||||
|
uint32_t random_state = time(NULL);
|
||||||
|
|
||||||
|
int vectors_length = ppm.width * ppm.height / 4;
|
||||||
|
double vectors[vectors_length][12];
|
||||||
|
rgb_to_vectors(ppm.data, ppm.width, ppm.height, vectors);
|
||||||
|
|
||||||
|
constexpr int codebook_length = 256;
|
||||||
|
double codebook[codebook_length][12];
|
||||||
|
|
||||||
|
int rgb_size = ppm.width * ppm.height * 3;
|
||||||
|
double min_error = 0; //std::numeric_limits<double>::infinity();
|
||||||
|
|
||||||
|
for (int i = 0; i < 2; i++) {
|
||||||
|
double new_codebook[codebook_length][12];
|
||||||
|
// find locally-optimal codebook
|
||||||
|
k_means_cluster<12>(&random_state,
|
||||||
|
codebook_length,
|
||||||
|
vectors,
|
||||||
|
vectors_length,
|
||||||
|
new_codebook);
|
||||||
|
|
||||||
|
uint8_t codepixels[vectors_length];
|
||||||
|
palettize_vectors_to_codebook(new_codebook, vectors, vectors_length, codepixels);
|
||||||
|
uint8_t rgb[rgb_size];
|
||||||
|
codepixels_to_rgb(new_codebook, codepixels, ppm.width, ppm.height, rgb);
|
||||||
|
|
||||||
|
double error = total_rgb_error(rgb, ppm.data, rgb_size);
|
||||||
|
if (i % 100 == 0)
|
||||||
|
printf("%d %.0f\n", i, min_error);
|
||||||
|
if (error > min_error) {
|
||||||
|
for (int i = 0; i < codebook_length; i++) {
|
||||||
|
set_vector<12>(codebook[i], new_codebook[i]);
|
||||||
|
}
|
||||||
|
min_error = error;
|
||||||
|
printf("%d new min_error %.0f\n", i, min_error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(min_error != std::numeric_limits<double>::infinity());
|
||||||
|
|
||||||
|
uint8_t codepixels[vectors_length];
|
||||||
|
palettize_vectors_to_codebook(codebook, vectors, vectors_length, codepixels);
|
||||||
|
|
||||||
|
printf("w %d h %d\n", ppm.width, ppm.height);
|
||||||
|
FILE *of = fopen(argv[2], "wb");
|
||||||
|
if (endswith(argv[2], ".ppm")) {
|
||||||
|
uint8_t rgb_out[rgb_size];
|
||||||
|
codepixels_to_rgb(codebook, codepixels, ppm.width, ppm.height, rgb_out);
|
||||||
|
|
||||||
|
fprintf(stderr, "writing ppm\n");
|
||||||
|
fprintf(of, "P6\n%d %d\n%d\n", ppm.width, ppm.height, 255);
|
||||||
|
ssize_t write_len = fwrite(rgb_out, rgb_size, 1, of);
|
||||||
|
assert(write_len == 1);
|
||||||
|
} else if (endswith(argv[2], ".vq")) {
|
||||||
|
fprintf(stderr, "writing vq codebook\n");
|
||||||
|
for (int i = 0; i < codebook_length; i++) {
|
||||||
|
uint64_t out = byteswap(color_convert(codebook[i]));
|
||||||
|
ssize_t write_len = fwrite(&out, (sizeof (uint64_t)), 1, of);
|
||||||
|
assert(write_len == 1);
|
||||||
|
}
|
||||||
|
fprintf(stderr, "writing vq codepixels\n");
|
||||||
|
int codepixel_width = ppm.width / 2;
|
||||||
|
int codepixel_height = ppm.height / 2;
|
||||||
|
int max_curve_ix = twiddle::from_xy(codepixel_width - 1, codepixel_height - 1, ppm.width, ppm.height);
|
||||||
|
uint8_t twiddled_codepixels[max_curve_ix];
|
||||||
|
twiddle::texture(twiddled_codepixels, codepixels, codepixel_width, codepixel_height);
|
||||||
|
ssize_t write_len = fwrite(twiddled_codepixels, max_curve_ix + 1, 1, of);
|
||||||
|
assert(write_len == 1);
|
||||||
|
}
|
||||||
|
fclose(of);
|
||||||
|
}
|
61
gen/k_means/ppm.c
Normal file
61
gen/k_means/ppm.c
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
#include "ppm.h"
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
static int advance(uint8_t * buf, int size, int index, uint8_t c)
|
||||||
|
{
|
||||||
|
while (index < size) {
|
||||||
|
if (buf[index] == c && buf[index + 1] != c) {
|
||||||
|
return index + 1;
|
||||||
|
}
|
||||||
|
index += 1;
|
||||||
|
}
|
||||||
|
fprintf(stderr, "end of file: expected `%d`\n", c);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int ppm_parse(uint8_t * buf, int size, struct ppm_header * out)
|
||||||
|
{
|
||||||
|
if (size < 2) {
|
||||||
|
fprintf(stderr, "file too small: %d\n", size);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
bool magic = buf[0] == 'P' && buf[1] == '6';
|
||||||
|
if (!magic) {
|
||||||
|
fprintf(stderr, "invalid magic: %c%c\n", buf[0], buf[1]);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int header[3];
|
||||||
|
int header_ix = 0;
|
||||||
|
int index = 2;
|
||||||
|
uint8_t delimiter = '\n';
|
||||||
|
|
||||||
|
while (header_ix < 3) {
|
||||||
|
index = advance(buf, size - index, index, delimiter);
|
||||||
|
if (buf[index] == '#')
|
||||||
|
continue;
|
||||||
|
|
||||||
|
uint8_t * end;
|
||||||
|
int n = strtol((const char *)&buf[index], (char **)&end, 10);
|
||||||
|
if (end == buf) {
|
||||||
|
fprintf(stderr, "expected integer at index %d\n", index);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
header[header_ix] = n;
|
||||||
|
|
||||||
|
index = end - buf;
|
||||||
|
delimiter = header_ix == 0 ? ' ' : '\n';
|
||||||
|
header_ix += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
index = advance(buf, size - index, index, '\n');
|
||||||
|
out->width = header[0];
|
||||||
|
out->height = header[1];
|
||||||
|
out->colors = header[2];
|
||||||
|
out->data = &buf[index];
|
||||||
|
out->length = size - index;
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
13
gen/k_means/ppm.h
Normal file
13
gen/k_means/ppm.h
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
struct ppm_header {
|
||||||
|
int width;
|
||||||
|
int height;
|
||||||
|
int colors;
|
||||||
|
uint8_t * data;
|
||||||
|
int length;
|
||||||
|
};
|
||||||
|
|
||||||
|
int ppm_parse(uint8_t * buf, int size, struct ppm_header * out);
|
92
gen/k_means/python/decode_pvrt.py
Normal file
92
gen/k_means/python/decode_pvrt.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
import sys
|
||||||
|
import struct
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
codebook_size = 256 * 2 * 4
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PVRT:
|
||||||
|
texture_data_size: int
|
||||||
|
texture_type: int
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
codebook: list[int]
|
||||||
|
indices: list[int]
|
||||||
|
|
||||||
|
def parse_pvrt_header(buf):
|
||||||
|
header = buf[0:16]
|
||||||
|
codebook = buf[16:codebook_size + 16]
|
||||||
|
indices = buf[codebook_size + 16:]
|
||||||
|
assert len(header) == 16
|
||||||
|
assert len(codebook) == codebook_size
|
||||||
|
|
||||||
|
assert header[0:4] == b"PVRT"
|
||||||
|
unpacked = struct.unpack('<LLHH', header[4:])
|
||||||
|
texture_data_size, texture_type, width, height = unpacked
|
||||||
|
print(texture_data_size)
|
||||||
|
print(hex(texture_type))
|
||||||
|
print(width)
|
||||||
|
print(height)
|
||||||
|
assert len(indices) + len(codebook) == texture_data_size - 8
|
||||||
|
#assert len(indices) == width * height / 4, (len(indices), width * height / 4)
|
||||||
|
return PVRT(
|
||||||
|
texture_data_size,
|
||||||
|
texture_type,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
codebook,
|
||||||
|
indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
def rgb24(color):
|
||||||
|
r = (color >> 11) & 31
|
||||||
|
g = (color >> 5) & 63
|
||||||
|
b = (color >> 0) & 31
|
||||||
|
return r << 3, g << 2, b << 3
|
||||||
|
|
||||||
|
def get_colors(buf, codebook_ix):
|
||||||
|
codeword = buf[codebook_ix * 2 * 4:][:2 * 4]
|
||||||
|
assert len(codeword) == 2 * 4
|
||||||
|
colors = struct.unpack('<HHHH', codeword)
|
||||||
|
return list(map(rgb24, colors))
|
||||||
|
|
||||||
|
def from_xy(x, y):
|
||||||
|
twiddle_ix = 0
|
||||||
|
i = 0
|
||||||
|
while i <= (20 / 2):
|
||||||
|
twiddle_ix |= ((y >> i) & 1) << (i * 2 + 0)
|
||||||
|
twiddle_ix |= ((x >> i) & 1) << (i * 2 + 1)
|
||||||
|
i += 1
|
||||||
|
return twiddle_ix
|
||||||
|
|
||||||
|
def decode_vq_indices(codebook, indices, width, height):
|
||||||
|
canvas = [0] * width * height
|
||||||
|
for ty in range(height // 2):
|
||||||
|
for tx in range(width // 2):
|
||||||
|
codebook_ix = indices[from_xy(tx, ty)]
|
||||||
|
codeword = get_colors(codebook, codebook_ix)
|
||||||
|
ai = ((ty * 2) + 0) * width + ((tx * 2) + 0)
|
||||||
|
bi = ((ty * 2) + 1) * width + ((tx * 2) + 0)
|
||||||
|
ci = ((ty * 2) + 0) * width + ((tx * 2) + 1)
|
||||||
|
di = ((ty * 2) + 1) * width + ((tx * 2) + 1)
|
||||||
|
print(width, height, ai, ty, tx)
|
||||||
|
canvas[ai] = codeword[0]
|
||||||
|
canvas[bi] = codeword[1]
|
||||||
|
canvas[ci] = codeword[2]
|
||||||
|
canvas[di] = codeword[3]
|
||||||
|
return canvas
|
||||||
|
|
||||||
|
in_filename = sys.argv[1]
|
||||||
|
out_filename = sys.argv[2]
|
||||||
|
|
||||||
|
with open(in_filename, 'rb') as f:
|
||||||
|
buf = f.read()
|
||||||
|
#pvrt = parse_pvrt_header(buf)
|
||||||
|
#canvas = decode_vq_indices(pvrt.codebook, pvrt.indices, pvrt.width, pvrt.height)
|
||||||
|
canvas = decode_vq_indices(buf[:256 * 4 * 2], buf[256*4*2:], 128, 64)
|
||||||
|
|
||||||
|
#palimage = Image.new('RGB', (pvrt.width, pvrt.height))
|
||||||
|
palimage = Image.new('RGB', (128, 64))
|
||||||
|
palimage.putdata(canvas)
|
||||||
|
palimage.save(out_filename)
|
89
gen/k_means/python/km.py
Normal file
89
gen/k_means/python/km.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import random
|
||||||
|
import math
|
||||||
|
from PIL import Image
|
||||||
|
import sys
|
||||||
|
from itertools import chain
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
def random_k_points(k, points):
|
||||||
|
points = list(points)
|
||||||
|
out = []
|
||||||
|
while len(out) < k and len(points) != 0:
|
||||||
|
ix = random.randint(0, len(points) - 1)
|
||||||
|
if points[ix] not in out:
|
||||||
|
value = points.pop(ix)
|
||||||
|
out.append(value)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def distance(point, centroid):
|
||||||
|
x0, y0, z0 = point
|
||||||
|
x1, y1, z1 = centroid
|
||||||
|
xd = (x1 - x0) ** 2
|
||||||
|
yd = (y1 - y0) ** 2
|
||||||
|
zd = (z1 - z0) ** 2
|
||||||
|
return math.sqrt(xd + yd + zd)
|
||||||
|
|
||||||
|
def argmin(l):
|
||||||
|
min = (0, l[0])
|
||||||
|
for ix, v in enumerate(l[1:]):
|
||||||
|
if v < min[1]:
|
||||||
|
min = (ix+1, v)
|
||||||
|
return min[0]
|
||||||
|
|
||||||
|
def calculate_centroid(points):
|
||||||
|
xs = 0
|
||||||
|
ys = 0
|
||||||
|
zs = 0
|
||||||
|
for x, y, z in points:
|
||||||
|
xs += x
|
||||||
|
ys += y
|
||||||
|
zs += z
|
||||||
|
return (xs / len(points), ys / len(points), zs / len(points))
|
||||||
|
|
||||||
|
def k_means_cluster(k, points):
|
||||||
|
# Initialization: choose k centroids (Forgy, Random Partition, etc.)
|
||||||
|
centroids = random_k_points(k, points)
|
||||||
|
assert len(centroids) == k
|
||||||
|
|
||||||
|
# Initialize clusters list
|
||||||
|
clusters = [[] for _ in range(k)]
|
||||||
|
|
||||||
|
# Loop until convergence
|
||||||
|
converged = False
|
||||||
|
while not converged:
|
||||||
|
# Clear previous clusters
|
||||||
|
clusters = [[] for _ in range(k)]
|
||||||
|
|
||||||
|
# Assign each point to the "closest" centroid
|
||||||
|
for point in points:
|
||||||
|
distances_to_each_centroid = [distance(point, centroid) for centroid in centroids]
|
||||||
|
cluster_assignment = argmin(distances_to_each_centroid)
|
||||||
|
clusters[cluster_assignment].append(point)
|
||||||
|
|
||||||
|
# Calculate new centroids
|
||||||
|
# (the standard implementation uses the mean of all points in a
|
||||||
|
# cluster to determine the new centroid)
|
||||||
|
print(clusters)
|
||||||
|
new_centroids = [calculate_centroid(cluster) for cluster in clusters]
|
||||||
|
|
||||||
|
converged = (new_centroids == centroids)
|
||||||
|
centroids = new_centroids
|
||||||
|
|
||||||
|
if converged:
|
||||||
|
return clusters
|
||||||
|
|
||||||
|
def palettize(centroids, point):
|
||||||
|
distances_to_each_centroid = [distance(point, centroid) for centroid in centroids]
|
||||||
|
cluster_assignment = argmin(distances_to_each_centroid)
|
||||||
|
return cluster_assignment
|
||||||
|
|
||||||
|
output = sys.argv[2]
|
||||||
|
with Image.open(sys.argv[1]) as im:
|
||||||
|
pixels = list(im.convert("RGB").getdata())
|
||||||
|
clusters = k_means_cluster(16, pixels)
|
||||||
|
palette = list(calculate_centroid(cluster) for cluster in clusters)
|
||||||
|
print(palette)
|
||||||
|
palimage = Image.new('P', im.size)
|
||||||
|
palimage.putpalette(map(int, list(chain.from_iterable(palette))))
|
||||||
|
palimage.putdata([palettize(palette, pixel) for pixel in pixels])
|
||||||
|
palimage.save(sys.argv[2])
|
217
gen/k_means/python/km_vq.py
Normal file
217
gen/k_means/python/km_vq.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
import random
|
||||||
|
import math
|
||||||
|
from PIL import Image
|
||||||
|
import sys
|
||||||
|
from itertools import chain
|
||||||
|
from pprint import pprint
|
||||||
|
from itertools import starmap
|
||||||
|
import struct
|
||||||
|
|
||||||
|
def random_k_points(k, points):
|
||||||
|
points = list(points)
|
||||||
|
out = []
|
||||||
|
while len(out) < k and len(points) != 0:
|
||||||
|
ix = random.randint(0, len(points) - 1)
|
||||||
|
if points[ix] not in out:
|
||||||
|
value = points.pop(ix)
|
||||||
|
out.append(value)
|
||||||
|
assert len(out) == k
|
||||||
|
return out
|
||||||
|
|
||||||
|
def distance_one(point, centroid):
|
||||||
|
x0, y0, z0 = point
|
||||||
|
x1, y1, z1 = centroid
|
||||||
|
xd = (x1 - x0) ** 2
|
||||||
|
yd = (y1 - y0) ** 2
|
||||||
|
zd = (z1 - z0) ** 2
|
||||||
|
return xd + yd + zd
|
||||||
|
|
||||||
|
def distance(point, centroid):
|
||||||
|
assert len(point) == 4
|
||||||
|
assert len(centroid) == 4
|
||||||
|
return math.sqrt(sum(starmap(distance_one, zip(point, centroid))))
|
||||||
|
|
||||||
|
def argmin(l):
|
||||||
|
min = (0, l[0])
|
||||||
|
for ix, v in enumerate(l[1:]):
|
||||||
|
if v < min[1]:
|
||||||
|
min = (ix+1, v)
|
||||||
|
return min[0]
|
||||||
|
|
||||||
|
def calculate_centroid_one(points):
|
||||||
|
xs = 0
|
||||||
|
ys = 0
|
||||||
|
zs = 0
|
||||||
|
for x, y, z in points:
|
||||||
|
xs += x
|
||||||
|
ys += y
|
||||||
|
zs += z
|
||||||
|
return (xs / len(points), ys / len(points), zs / len(points))
|
||||||
|
|
||||||
|
def calculate_centroid(points):
|
||||||
|
t = tuple(map(calculate_centroid_one, zip(*points)))
|
||||||
|
assert len(t) == 4, t
|
||||||
|
return t
|
||||||
|
|
||||||
|
def k_means_cluster(k, points):
|
||||||
|
# Initialization: choose k centroids (Forgy, Random Partition, etc.)
|
||||||
|
centroids = random_k_points(k, points)
|
||||||
|
assert len(centroids) == k
|
||||||
|
|
||||||
|
# Initialize clusters list
|
||||||
|
clusters = [[] for _ in range(k)]
|
||||||
|
|
||||||
|
# Loop until convergence
|
||||||
|
converged = False
|
||||||
|
while not converged:
|
||||||
|
# Clear previous clusters
|
||||||
|
clusters = [[] for _ in range(k)]
|
||||||
|
|
||||||
|
# Assign each point to the "closest" centroid
|
||||||
|
for point in points:
|
||||||
|
distances_to_each_centroid = [distance(point, centroid) for centroid in centroids]
|
||||||
|
cluster_assignment = argmin(distances_to_each_centroid)
|
||||||
|
clusters[cluster_assignment].append(point)
|
||||||
|
assert(all(len(c) != 0 for c in clusters))
|
||||||
|
print(clusters)
|
||||||
|
# Calculate new centroids
|
||||||
|
# (the standard implementation uses the mean of all points in a
|
||||||
|
# cluster to determine the new centroid)
|
||||||
|
new_centroids = [calculate_centroid(cluster) for cluster in clusters]
|
||||||
|
|
||||||
|
converged = (new_centroids == centroids)
|
||||||
|
centroids = new_centroids
|
||||||
|
|
||||||
|
if converged:
|
||||||
|
return clusters
|
||||||
|
|
||||||
|
def palettize(centroids, point):
|
||||||
|
distances_to_each_centroid = [distance(point, centroid) for centroid in centroids]
|
||||||
|
cluster_assignment = argmin(distances_to_each_centroid)
|
||||||
|
return cluster_assignment
|
||||||
|
|
||||||
|
def rgb565(color):
|
||||||
|
r, g, b = color
|
||||||
|
return r >> 3, g >> 2, b >> 3
|
||||||
|
|
||||||
|
def pixels_to_codebook(pixels, width, height):
|
||||||
|
for ty in range(height // 2):
|
||||||
|
for tx in range(width // 2):
|
||||||
|
ai = ((ty * 2) + 0) * width + ((tx * 2) + 0)
|
||||||
|
bi = ((ty * 2) + 1) * width + ((tx * 2) + 0)
|
||||||
|
ci = ((ty * 2) + 0) * width + ((tx * 2) + 1)
|
||||||
|
di = ((ty * 2) + 1) * width + ((tx * 2) + 1)
|
||||||
|
codeword = pixels[ai], pixels[bi], pixels[ci], pixels[di]
|
||||||
|
yield codeword
|
||||||
|
|
||||||
|
def codebook_codepixels_to_pixels(codebook, codepixels, width, height):
|
||||||
|
canvas = [0] * (width * height)
|
||||||
|
for ty in range(height // 2):
|
||||||
|
for tx in range(width // 2):
|
||||||
|
codepixel = codepixels[ty * width // 2 + tx]
|
||||||
|
assignment = palettize(codebook, codepixel)
|
||||||
|
ap, bp, cp, dp = codebook[assignment]
|
||||||
|
ai = ((ty * 2) + 0) * width + ((tx * 2) + 0)
|
||||||
|
bi = ((ty * 2) + 1) * width + ((tx * 2) + 0)
|
||||||
|
ci = ((ty * 2) + 0) * width + ((tx * 2) + 1)
|
||||||
|
di = ((ty * 2) + 1) * width + ((tx * 2) + 1)
|
||||||
|
canvas[ai] = ap
|
||||||
|
canvas[bi] = bp
|
||||||
|
canvas[ci] = cp
|
||||||
|
canvas[di] = dp
|
||||||
|
return canvas
|
||||||
|
|
||||||
|
def remove_gamma(c):
|
||||||
|
c /= 255
|
||||||
|
if c <= 0.04045:
|
||||||
|
return c / 12.92
|
||||||
|
else:
|
||||||
|
return ((c + 0.055)/1.055) ** 2.4
|
||||||
|
|
||||||
|
def apply_gamma(c):
|
||||||
|
if c <= 0.0031308:
|
||||||
|
c2 = c * 12.92
|
||||||
|
else:
|
||||||
|
c2 = 1.055 * (c ** (1/2.4)) - 0.055
|
||||||
|
return round(c2 * 255)
|
||||||
|
|
||||||
|
def apply_gamma_v(v):
|
||||||
|
assert len(v) == 3
|
||||||
|
return tuple(map(apply_gamma, v))
|
||||||
|
|
||||||
|
def remove_gamma_v(v):
|
||||||
|
assert len(v) == 3
|
||||||
|
return tuple(map(remove_gamma, v))
|
||||||
|
|
||||||
|
for i in range(0, 256):
|
||||||
|
rt = apply_gamma(remove_gamma(i))
|
||||||
|
assert rt == i, (rt, i)
|
||||||
|
|
||||||
|
def mat3x3_mul_v(mat, v):
|
||||||
|
def dot(row):
|
||||||
|
return mat[row][0] * v[0] + \
|
||||||
|
mat[row][1] * v[1] + \
|
||||||
|
mat[row][2] * v[2]
|
||||||
|
|
||||||
|
return tuple((dot(0), dot(1), dot(2)))
|
||||||
|
|
||||||
|
def srgb_to_ciexyz(color):
|
||||||
|
mat = [[0.4124564, 0.3575761, 0.1804375],
|
||||||
|
[0.2126729, 0.7151522, 0.0721750],
|
||||||
|
[0.0193339, 0.1191920, 0.9503041]]
|
||||||
|
return mat3x3_mul_v(mat, color)
|
||||||
|
|
||||||
|
def ciexyz_to_srgb(color):
|
||||||
|
mat = [[ 3.2404542, -1.5371385, -0.4985314],
|
||||||
|
[-0.9692660, 1.8760108, 0.0415560],
|
||||||
|
[ 0.0556434, -0.2040259, 1.0572252]]
|
||||||
|
return mat3x3_mul_v(mat, color)
|
||||||
|
|
||||||
|
def rgb24(color):
|
||||||
|
r, g, b = color
|
||||||
|
return round(r) * (1 << 3), round(g) * (1 << 2), round(b) * (1 << 3)
|
||||||
|
|
||||||
|
for _ in range(256):
|
||||||
|
rcolor = tuple(random.randint(0, 255) for _ in range(3))
|
||||||
|
rtcolor = ciexyz_to_srgb(srgb_to_ciexyz(rcolor))
|
||||||
|
assert rcolor == tuple(map(round, rtcolor))
|
||||||
|
|
||||||
|
def write_binary_vq(f, codebook, codepixels):
|
||||||
|
# ᴎ
|
||||||
|
for colors in codebook:
|
||||||
|
for color in colors:
|
||||||
|
r, g, b = map(round, color)
|
||||||
|
assert r <= 31 and g <= 63 and b <= 31
|
||||||
|
n = (r << 11) | (g << 5) | (b << 0)
|
||||||
|
f.write(struct.pack('<H', n))
|
||||||
|
for codepixel in codepixels:
|
||||||
|
assignment = palettize(codebook, codepixel)
|
||||||
|
assert assignment <= 255 and assignment >= 0
|
||||||
|
f.write(bytes([assignment]))
|
||||||
|
|
||||||
|
def do(filename, output):
|
||||||
|
with Image.open(filename) as im:
|
||||||
|
pixels = list(im.convert("RGB").getdata())
|
||||||
|
#ciexyz_pixels = [srgb_to_ciexyz(p) for p in pixels]
|
||||||
|
rgb565_pixels = [rgb565(p) for p in pixels]
|
||||||
|
width, height = im.size
|
||||||
|
codepixels = list(pixels_to_codebook(rgb565_pixels, width, height))
|
||||||
|
|
||||||
|
clusters = k_means_cluster(256, codepixels)
|
||||||
|
|
||||||
|
codebook = list(calculate_centroid(cluster) for cluster in clusters)
|
||||||
|
|
||||||
|
canvas_ciexyz = codebook_codepixels_to_pixels(codebook, codepixels, width, height)
|
||||||
|
canvas_rgb = [tuple(map(round, rgb24(p))) for p in canvas_ciexyz]
|
||||||
|
|
||||||
|
palimage = Image.new('RGB', im.size)
|
||||||
|
palimage.putdata(canvas_rgb)
|
||||||
|
palimage.save(output)
|
||||||
|
|
||||||
|
with open(output.split('.', maxsplit=1)[0] + '.vq.bin', 'wb') as f:
|
||||||
|
write_binary_vq(f, codebook, codepixels)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
in_file = sys.argv[1]
|
||||||
|
out_file = sys.argv[2]
|
||||||
|
do(in_file, out_file)
|
137
gen/k_means/test/k_means_cluster.cpp
Normal file
137
gen/k_means/test/k_means_cluster.cpp
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
#include <stdbool.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "runner.h"
|
||||||
|
#include "../k_means_vq.cpp"
|
||||||
|
|
||||||
|
static bool k_means_vq_0(const char ** scenario)
|
||||||
|
{
|
||||||
|
*scenario = "random_k_points all indices";
|
||||||
|
|
||||||
|
uint32_t random_state = 0x12345678;
|
||||||
|
const int k = 10;
|
||||||
|
const int length = 10;
|
||||||
|
int indices[length];
|
||||||
|
memset(indices, 0xff, (sizeof (indices)));
|
||||||
|
random_k_points(&random_state, k, length, indices);
|
||||||
|
bool duplicate = false;
|
||||||
|
uint32_t seen = 0;
|
||||||
|
for (int i = 0; i < k; i++) {
|
||||||
|
uint32_t bit = 1 << indices[i];
|
||||||
|
if (seen & bit) duplicate = true;
|
||||||
|
seen |= bit;
|
||||||
|
}
|
||||||
|
return
|
||||||
|
duplicate == false &&
|
||||||
|
seen == 0x3ff
|
||||||
|
;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool k_means_vq_1(const char ** scenario)
|
||||||
|
{
|
||||||
|
*scenario = "random_k_points vaugely random";
|
||||||
|
|
||||||
|
uint32_t random_state = 0x12345678;
|
||||||
|
const int length = 30;
|
||||||
|
const int k = 10;
|
||||||
|
int indices[length];
|
||||||
|
const int num_tests = 6;
|
||||||
|
uint32_t seen[num_tests] = {0};
|
||||||
|
|
||||||
|
bool out_of_range = false;
|
||||||
|
bool duplicate = false;
|
||||||
|
for (int j = 0; j < num_tests; j++) {
|
||||||
|
random_k_points(&random_state, k, length, indices);
|
||||||
|
for (int i = 0; i < k; i++) {
|
||||||
|
int point_ix = indices[i];
|
||||||
|
if (point_ix > length) out_of_range = true;
|
||||||
|
uint32_t bit = 1 << point_ix;
|
||||||
|
if (seen[j] & bit) duplicate = true;
|
||||||
|
seen[j] |= bit;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
out_of_range == false &&
|
||||||
|
duplicate == false &&
|
||||||
|
seen[0] != 0 &&
|
||||||
|
seen[1] != 0 &&
|
||||||
|
seen[2] != 0 &&
|
||||||
|
seen[3] != 0 &&
|
||||||
|
seen[4] != 0 &&
|
||||||
|
seen[5] != 0 &&
|
||||||
|
seen[0] < 0xffffffff &&
|
||||||
|
seen[1] < 0xffffffff &&
|
||||||
|
seen[2] < 0xffffffff &&
|
||||||
|
seen[3] < 0xffffffff &&
|
||||||
|
seen[4] < 0xffffffff &&
|
||||||
|
seen[5] < 0xffffffff &&
|
||||||
|
seen[0] != seen[1] && seen[0] != seen[2] && seen[0] != seen[3] && seen[0] != seen[4] && seen[0] != seen[5] &&
|
||||||
|
seen[1] != seen[2] && seen[1] != seen[3] && seen[1] != seen[4] && seen[1] != seen[5] &&
|
||||||
|
seen[2] != seen[3] && seen[2] != seen[4] && seen[2] != seen[5] &&
|
||||||
|
seen[3] != seen[4] && seen[3] != seen[5] &&
|
||||||
|
seen[4] != seen[5]
|
||||||
|
;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool equal(double a, double b)
|
||||||
|
{
|
||||||
|
return a - b < 0.00001;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool k_means_vq_2(const char ** scenario)
|
||||||
|
{
|
||||||
|
*scenario = "calculate_centroid";
|
||||||
|
|
||||||
|
double cluster[2][3] = {
|
||||||
|
{1, 5, 9},
|
||||||
|
{3, 7, 11},
|
||||||
|
};
|
||||||
|
|
||||||
|
double result[3];
|
||||||
|
|
||||||
|
calculate_centroid<3>(&cluster[0], 2, result);
|
||||||
|
|
||||||
|
return
|
||||||
|
equal(result[0], 2) &&
|
||||||
|
equal(result[1], 6);
|
||||||
|
equal(result[2], 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool k_means_vq_3(const char ** scenario)
|
||||||
|
{
|
||||||
|
*scenario = "minimum_distance_centroid";
|
||||||
|
|
||||||
|
int min_ix[2];
|
||||||
|
|
||||||
|
constexpr int k = 5;
|
||||||
|
double centroids[k][2] = {
|
||||||
|
{5, 10},
|
||||||
|
{5, 4},
|
||||||
|
{2, 1},
|
||||||
|
{10, 20},
|
||||||
|
{6, 6},
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
double point[2] = {4, 3};
|
||||||
|
min_ix[0] = minimum_distance_centroid<2>(point, centroids, k);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
double point[2] = {11, 21};
|
||||||
|
min_ix[1] = minimum_distance_centroid<2>(point, centroids, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
min_ix[0] == 1 &&
|
||||||
|
min_ix[1] == 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
test_t k_means_vq_tests[] = {
|
||||||
|
k_means_vq_0,
|
||||||
|
k_means_vq_1,
|
||||||
|
k_means_vq_2,
|
||||||
|
k_means_vq_3,
|
||||||
|
};
|
||||||
|
|
||||||
|
RUNNER(k_means_vq_tests);
|
0
gen/k_means/test/k_means_vq.cpp
Normal file
0
gen/k_means/test/k_means_vq.cpp
Normal file
27
gen/k_means/test/runner.h
Normal file
27
gen/k_means/test/runner.h
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
typedef bool (*test_t)(const char ** scenario);
|
||||||
|
|
||||||
|
#define ANSI_RED "\x1b[31m"
|
||||||
|
#define ANSI_GREEN "\x1b[32m"
|
||||||
|
#define ANSI_RESET "\x1b[0m"
|
||||||
|
|
||||||
|
#define RUNNER(tests) \
|
||||||
|
int main() \
|
||||||
|
{ \
|
||||||
|
int fail_count = 0; \
|
||||||
|
for (int i = 0; i < (sizeof (tests)) / (sizeof (test_t)); i++) { \
|
||||||
|
const char * scenario = NULL; \
|
||||||
|
bool result = tests[i](&scenario); \
|
||||||
|
const char * result_s = result ? "ok" : ANSI_RED "fail" ANSI_RESET; \
|
||||||
|
fail_count += !result; \
|
||||||
|
fprintf(stderr, "%s: %s\n", scenario, result_s); \
|
||||||
|
} \
|
||||||
|
if (fail_count == 0) { \
|
||||||
|
fprintf(stderr, ANSI_GREEN "failed tests: %d\n\n" ANSI_RESET, fail_count); \
|
||||||
|
} else { \
|
||||||
|
fprintf(stderr, ANSI_RED "failed tests: %d\n\n" ANSI_RESET, fail_count); \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
return !(fail_count == 0); \
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user