diff --git a/typescript/package.json b/typescript/package.json index a5d07f06..a9c153cd 100644 --- a/typescript/package.json +++ b/typescript/package.json @@ -45,15 +45,23 @@ "scripts": { "build": "wireit", "build:ts": "wireit", - "prepublishOnly": "npm run build" + "prepublishOnly": "npm run build", + "test": "npm run build && jasmine-browser-runner runSpecs", + "test:serve": "npm run build && jasmine-browser-runner serve" }, "wireit": { "build": { - "dependencies": ["build:ts"] + "dependencies": [ + "build:ts" + ] }, "build:ts": { "command": "tsc --pretty", - "files": ["tsconfig.json", "**/*.ts", "!**/*.d.ts"], + "files": [ + "tsconfig.json", + "**/*.ts", + "!**/*.d.ts" + ], "output": [ ".tsbuildinfo", "**/*.js", @@ -66,10 +74,13 @@ } }, "devDependencies": { - "@types/jasmine": "^3.10.3", + "@types/jasmine": "^3.10.18", "@types/node": "^18.7.17", + "@webgpu/types": "^0.1.54", "jasmine": "^4.0.2", - "typescript": "^4.5.5", + "jasmine-browser-runner": "^2.5.0", + "jasmine-core": "^5.6.0", + "typescript": "^4.9.5", "wireit": "^0.9.5" } } diff --git a/typescript/palettes/palettes_test.ts b/typescript/palettes/palettes_test.ts index bdc1c828..15fc2fe7 100644 --- a/typescript/palettes/palettes_test.ts +++ b/typescript/palettes/palettes_test.ts @@ -17,8 +17,6 @@ import 'jasmine'; -import {Hct} from '../hct/hct.js'; - import {CorePalette} from './core_palette.js'; import {TonalPalette} from './tonal_palette.js'; diff --git a/typescript/quantize-webgpu/kmeans/index.ts b/typescript/quantize-webgpu/kmeans/index.ts new file mode 100644 index 00000000..03eec49f --- /dev/null +++ b/typescript/quantize-webgpu/kmeans/index.ts @@ -0,0 +1,84 @@ +import { setupCompute } from './pipelines/compute.js'; + +export async function extractDominantColorsKMeansGPU( + device: GPUDevice, + pixels: number[], + K: number, + initialCentroidsBuffer: GPUBuffer | null = null +): Promise { + const MAX_ITERATIONS = 256; + const CONVERGENCE_EPS = 0.01; + const CONVERGENCE_CHECK = 8; + + const { + 
colorCount, + centroidsBuffer, + centroidsDeltaBuffer, + assignPipeline, + updatePipeline, + computeBindGroup + } = await setupCompute(device, pixels, K); + + const stagingCentroidsDeltaBuffer = device.createBuffer({ + label: 'centroids-delta-staging', + size: K * Float32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ + }); + + let encoder = device.createCommandEncoder(); + + if (initialCentroidsBuffer) { + encoder.copyBufferToBuffer( + initialCentroidsBuffer, 0, + centroidsBuffer, 0, + 3 * K * Float32Array.BYTES_PER_ELEMENT + ); + } else { + const centroids = new Float32Array(3 * K); + for (let i = 0; i < 3 * K; i++) { + centroids[i] = Math.random(); + } + device.queue.writeBuffer(centroidsBuffer, 0, centroids); + } + + for (let i = 0; i < MAX_ITERATIONS; i++) { + const assignPass = encoder.beginComputePass(); + assignPass.setPipeline(assignPipeline); + assignPass.setBindGroup(0, computeBindGroup); + assignPass.dispatchWorkgroups(Math.ceil(colorCount / 256)); + assignPass.end(); + + const updatePass = encoder.beginComputePass(); + updatePass.setPipeline(updatePipeline); + updatePass.setBindGroup(0, computeBindGroup); + updatePass.dispatchWorkgroups(Math.ceil(K / 16)); + updatePass.end(); + + if (i !== 0 && i % CONVERGENCE_CHECK === 0) { + encoder.copyBufferToBuffer( + centroidsDeltaBuffer, 0, + stagingCentroidsDeltaBuffer, 0, + K * Float32Array.BYTES_PER_ELEMENT + ); + + const commandBuffer = encoder.finish(); + device.queue.submit([commandBuffer]); + encoder = device.createCommandEncoder(); + + await stagingCentroidsDeltaBuffer.mapAsync(GPUMapMode.READ, 0, K * Float32Array.BYTES_PER_ELEMENT); + const centroidsDeltaData = new Float32Array(stagingCentroidsDeltaBuffer.getMappedRange()); + const deltaSum = centroidsDeltaData.reduce((acc, val) => acc + val, 0); + stagingCentroidsDeltaBuffer.unmap(); + if (deltaSum < CONVERGENCE_EPS) { + console.log(`Convergence reached at iteration ${i}`); + break; + } + } + } + + 
device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + + return centroidsBuffer; +} + diff --git a/typescript/quantize-webgpu/kmeans/pipelines/compute.ts b/typescript/quantize-webgpu/kmeans/pipelines/compute.ts new file mode 100644 index 00000000..f4b26e4b --- /dev/null +++ b/typescript/quantize-webgpu/kmeans/pipelines/compute.ts @@ -0,0 +1,145 @@ +import * as colorUtils from '../../../utils/color_utils.js'; + +interface ComputeResult { + colorCount: number; + centroidsBuffer: GPUBuffer; + centroidsDeltaBuffer: GPUBuffer; + assignPipeline: GPUComputePipeline; + updatePipeline: GPUComputePipeline; + computeBindGroup: GPUBindGroup; +} + +export async function setupCompute( + device: GPUDevice, + pixels: number[], + K: number +): Promise { + const colorHistogram = new Map(); + for (let i = 0; i < pixels.length; i++) { + const pixel = pixels[i]; + const r = colorUtils.redFromArgb(pixel); + const g = colorUtils.greenFromArgb(pixel); + const b = colorUtils.blueFromArgb(pixel); + const key = (r << 16) | (g << 8) | b; + colorHistogram.set(key, (colorHistogram.get(key) ?? 
0) + 1); + } + + const colorCount = colorHistogram.size; + const histogramArray = new Float32Array(colorCount * 4); + let i = 0; + for (const [key, count] of colorHistogram) { + const r = (key >> 16) & 0xFF; + const g = (key >> 8) & 0xFF; + const b = key & 0xFF; + + histogramArray[i * 4] = r / 255; + histogramArray[i * 4 + 1] = g / 255; + histogramArray[i * 4 + 2] = b / 255; + histogramArray[i * 4 + 3] = count; + i++; + } + + const countsUniformBuffer = device.createBuffer({ + size: 2 * Uint32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST + }); + device.queue.writeBuffer(countsUniformBuffer, 0, new Uint32Array([K, colorCount])); + + const histogramBuffer = device.createBuffer({ + label: 'histogram-compute', + size: histogramArray.byteLength, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST + }); + device.queue.writeBuffer(histogramBuffer, 0, histogramArray); + + const centroidsBuffer = device.createBuffer({ + label: 'centroids-compute', + size: 3 * K * Float32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST | GPUBufferUsage.COPY_SRC + }); + + const clustersBuffer = device.createBuffer({ + label: 'clusters-compute', + size: colorCount * Uint32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC + }); + + const centroidsDeltaBuffer = device.createBuffer({ + label: 'centroids-delta-compute', + size: K * Float32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC + }); + + const kUniformBuffer = device.createBuffer({ + size: Uint32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST + }); + device.queue.writeBuffer(kUniformBuffer, 0, new Uint32Array([K])); + + const assignModule = device.createShaderModule({ + code: await fetch(new URL('../shaders/assign.wgsl', import.meta.url).toString()).then(res => res.text()) + }); + const updateModule = device.createShaderModule({ + code: await fetch(new 
URL('../shaders/update.wgsl', import.meta.url).toString()).then(res => res.text()) + }); + + const computeBindGroupLayout = device.createBindGroupLayout({ + entries: [{ + binding: 0, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'read-only-storage' } + }, + { + binding: 1, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'uniform' } + }, + { + binding: 2, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }, + { + binding: 3, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }, + { + binding: 4, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }] + }); + + const computeBindGroup = device.createBindGroup({ + layout: computeBindGroupLayout, + entries: [ + { binding: 0, resource: { buffer: histogramBuffer } }, + { binding: 1, resource: { buffer: countsUniformBuffer } }, + { binding: 2, resource: { buffer: centroidsBuffer } }, + { binding: 3, resource: { buffer: clustersBuffer } }, + { binding: 4, resource: { buffer: centroidsDeltaBuffer } } + ] + }); + const computePipelineLayout = device.createPipelineLayout({ + bindGroupLayouts: [computeBindGroupLayout] + }); + + const updatePipeline = device.createComputePipeline({ + layout: computePipelineLayout, + compute: { module: updateModule } + }); + const assignPipeline = device.createComputePipeline({ + layout: computePipelineLayout, + compute: { module: assignModule } + }); + + return { + colorCount, + centroidsBuffer, + centroidsDeltaBuffer, + assignPipeline, + updatePipeline, + computeBindGroup + }; +} \ No newline at end of file diff --git a/typescript/quantize-webgpu/kmeans/shaders/assign.wgsl b/typescript/quantize-webgpu/kmeans/shaders/assign.wgsl new file mode 100644 index 00000000..ce1c5a7d --- /dev/null +++ b/typescript/quantize-webgpu/kmeans/shaders/assign.wgsl @@ -0,0 +1,40 @@ +struct Counts { + centroids: u32, + colors: u32 +}; + +@group(0) @binding(0) var histogram: array; +@group(0) @binding(1) var counts: Counts; +@group(0) 
@binding(2) var centroids: array; +@group(0) @binding(3) var clusters: array; + +fn dist(a: vec3f, b: vec3f) -> f32 { + return pow((a.x - b.x), 2) + pow((a.y - b.y), 2) + pow((a.z - b.z), 2); +} + +@compute @workgroup_size(256) +fn cs(@builtin(global_invocation_id) id: vec3u) { + if (id.x >= counts.colors) { + return; + } + + let pos = vec3f(histogram[id.x * 4], histogram[id.x * 4 + 1], histogram[id.x * 4 + 2]); + let count = histogram[id.x * 4 + 3]; + + var min_dist = -1.; + var closest = 0u; + + for (var i = 0u; i < counts.centroids; i++) { + let centroid = vec3f(centroids[3*i], centroids[3*i + 1], centroids[3*i + 2]); + if (centroid.x == -1.0 || centroid.y == -1.0 || centroid.z == -1.0) { + continue; + } + let d = dist(pos, centroid); + if (min_dist == -1 || d < min_dist){ + closest = i; + min_dist = d; + } + } + + clusters[id.x] = closest; +} diff --git a/typescript/quantize-webgpu/kmeans/shaders/update.wgsl b/typescript/quantize-webgpu/kmeans/shaders/update.wgsl new file mode 100644 index 00000000..615aa2b9 --- /dev/null +++ b/typescript/quantize-webgpu/kmeans/shaders/update.wgsl @@ -0,0 +1,52 @@ +struct Counts { + centroids: u32, + colors: u32 +}; + +@group(0) @binding(0) var histogram: array; +@group(0) @binding(1) var counts: Counts; +@group(0) @binding(2) var centroids: array; +@group(0) @binding(3) var clusters: array; +@group(0) @binding(4) var centroids_delta: array; + +fn dist(a: vec3f, b: vec3f) -> f32 { + return sqrt(pow((a.x - b.x), 2) + pow((a.y - b.y), 2) + pow((a.z - b.z), 2)); +} + +@compute @workgroup_size(16) +fn cs(@builtin(global_invocation_id) id: vec3u) { + let centroid = id.x; + + if (centroid >= counts.centroids) { + return; + } + + var sum = vec3f(0); + var count = 0u; + + for (var i = 0u; i < counts.colors; i++) { + if (clusters[i] == centroid) { + let pixel = vec3f(histogram[i * 4], histogram[i * 4 + 1], histogram[i * 4 + 2]); + let pixel_count = u32(histogram[i * 4 + 3]); + sum += pixel * f32(pixel_count); + count += pixel_count; + } 
+ } + + if (count > 0u) { + let old_pos = vec3f(centroids[3*centroid], centroids[3*centroid + 1], centroids[3*centroid + 2]); + let new_pos = sum / f32(count); + + centroids[3*centroid] = new_pos.x; + centroids[3*centroid + 1] = new_pos.y; + centroids[3*centroid + 2] = new_pos.z; + + let d = dist(old_pos, new_pos); + centroids_delta[centroid] = d; + } else { + centroids[3*centroid] = -1.0; + centroids[3*centroid + 1] = -1.0; + centroids[3*centroid + 2] = -1.0; + centroids_delta[centroid] = 0.0; + } +} diff --git a/typescript/quantize-webgpu/quantizer_celebi.ts b/typescript/quantize-webgpu/quantizer_celebi.ts new file mode 100644 index 00000000..7727052a --- /dev/null +++ b/typescript/quantize-webgpu/quantizer_celebi.ts @@ -0,0 +1,72 @@ +import { extractDominantColorsWuGPU } from './wu/index.js'; +import { extractDominantColorsKMeansGPU } from './kmeans/index.js'; +import * as colorUtils from '../utils/color_utils.js'; + +export class QuantizerCelebi { + static async quantize(pixels: number[], maxColors: number): Promise { + if (typeof navigator === 'undefined') { + throw new Error('Not in browser environment'); + } + + const adapter = await navigator.gpu.requestAdapter(); + const device = await adapter.requestDevice(); + if (!device) { + throw new Error('WebGPU not supported'); + } + + const textureData = colorUtils.pixelsToTextureData(pixels); + const textureSize = pixels.length; + const texture = device.createTexture({ + size: [textureSize, 1], + format: 'rgba8unorm', + usage: GPUTextureUsage.TEXTURE_BINDING | + GPUTextureUsage.COPY_DST | + GPUTextureUsage.RENDER_ATTACHMENT | + GPUTextureUsage.STORAGE_BINDING | + GPUTextureUsage.COPY_SRC + }); + + device.queue.writeTexture( + { texture }, + textureData, + { bytesPerRow: pixels.length * 4, rowsPerImage: 1 }, + { width: pixels.length, height: 1 } + ); + + const wuResultsBuffer = await extractDominantColorsWuGPU(device, texture, textureSize, maxColors); + const resultsBuffer = await 
extractDominantColorsKMeansGPU(device, pixels, maxColors, wuResultsBuffer); + + const stagingResultsBuffer = device.createBuffer({ + size: 3 * maxColors * Float32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ + }); + + const encoder = device.createCommandEncoder(); + encoder.copyBufferToBuffer( + resultsBuffer, 0, + stagingResultsBuffer, 0, + 3 * maxColors * Float32Array.BYTES_PER_ELEMENT + ); + device.queue.submit([encoder.finish()]); + + await stagingResultsBuffer.mapAsync(GPUMapMode.READ, 0, 3 * maxColors * Float32Array.BYTES_PER_ELEMENT); + const mappedData = stagingResultsBuffer.getMappedRange(); + const colors = new Float32Array(mappedData.slice(0)); + stagingResultsBuffer.unmap(); + + const result = []; + for (let i = 0; i < colors.length; i += 3) { + const isValid = [colors[i], colors[i + 1], colors[i + 2]].every(x => !isNaN(x) && x >= 0); + if (isValid) { + const r = Math.round(colors[i] * 255); + const g = Math.round(colors[i + 1] * 255); + const b = Math.round(colors[i + 2] * 255); + const argb = colorUtils.argbFromRgb(r, g, b); + result.push(argb); + } + } + + texture.destroy(); + return result; + } +} diff --git a/typescript/quantize-webgpu/quantizer_celebi_test.ts b/typescript/quantize-webgpu/quantizer_celebi_test.ts new file mode 100644 index 00000000..56423e83 --- /dev/null +++ b/typescript/quantize-webgpu/quantizer_celebi_test.ts @@ -0,0 +1,63 @@ +/** + * @license + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { QuantizerCelebi } from './quantizer_celebi.js'; + +const RED = 0xffff0000; +const GREEN = 0xff00ff00; +const BLUE = 0xff0000ff; + +describe('QuantizerCelebi', () => { + it('1R', async () => { + const answer = await QuantizerCelebi.quantize([RED], 128); + expect(answer.length).toBe(1); + expect(answer[0]).toBe(RED); + }); + + it('1G', async () => { + const answer = await QuantizerCelebi.quantize([GREEN], 128); + expect(answer.length).toBe(1); + expect(answer[0]).toBe(GREEN); + }); + + it('1B', async () => { + const answer = await QuantizerCelebi.quantize([BLUE], 128); + expect(answer.length).toBe(1); + expect(answer[0]).toBe(BLUE); + }); + + it('5B', async () => { + const answer = await QuantizerCelebi.quantize([BLUE, BLUE, BLUE, BLUE, BLUE], 128); + expect(answer.length).toBe(1); + expect(answer[0]).toBe(BLUE); + }); + + it('2R 3G', async () => { + const answer = await QuantizerCelebi.quantize([RED, RED, GREEN, GREEN, GREEN], 128); + expect(answer.length).toBe(2); + expect(answer).toContain(RED); + expect(answer).toContain(GREEN); + }); + + it('1R 1G 1B', async () => { + const answer = await QuantizerCelebi.quantize([RED, GREEN, BLUE], 4); + expect(answer.length).toBe(3); + expect(answer).toContain(RED); + expect(answer).toContain(GREEN); + expect(answer).toContain(BLUE); + }); +}); diff --git a/typescript/quantize-webgpu/wu/index.ts b/typescript/quantize-webgpu/wu/index.ts new file mode 100644 index 00000000..042a399b --- /dev/null +++ b/typescript/quantize-webgpu/wu/index.ts @@ -0,0 +1,119 @@ +import { setupBuildHistogram } from './pipelines/buildHistogram.js'; +import { setupComputeMoments } from './pipelines/computeMoments.js'; +import { setupCreateBox } from './pipelines/createBox.js'; +import { setupCreateResult } from './pipelines/createResult.js'; + +export async function extractDominantColorsWuGPU( + device: GPUDevice, + texture: 
GPUTexture, + textureSize: number, + K: number +): Promise { + const WORKGROUP_SIZE = 16; + + const TOTAL_SIZE = 35937; + const { + weightsBuffer, + momentsRBuffer, + momentsGBuffer, + momentsBBuffer, + momentsBuffer: mBuffer, + buildHistogramPipeline, + inputBindGroup, + buildHistogramBindGroup, + buildHistogramBindGroupLayout + } = await setupBuildHistogram(device, texture); + + const { + computeMomentsAxisBindGroups, + computeMomentsPipeline + } = await setupComputeMoments(device, buildHistogramBindGroupLayout); + + const { + momentsBuffer, + momentsBindGroup, + totalCubesNumUniformBuffer, + momentsBindGroupLayout, + cubesBuffer, + cubesBindGroup, + createBoxPipeline + } = await setupCreateBox(device, K); + + const { + resultsBuffer, + cubesResultBindGroup, + resultsBindGroup, + createResultPipeline + } = await setupCreateResult(device, K, momentsBindGroupLayout, cubesBuffer, totalCubesNumUniformBuffer); + + let encoder = device.createCommandEncoder(); + const buildHistogramPass = encoder.beginComputePass(); + buildHistogramPass.setPipeline(buildHistogramPipeline); + buildHistogramPass.setBindGroup(0, inputBindGroup); + buildHistogramPass.setBindGroup(1, buildHistogramBindGroup); + buildHistogramPass.dispatchWorkgroups(Math.ceil(textureSize / 256)); + buildHistogramPass.end(); + + const workGroupsPerDim = Math.ceil(32 / WORKGROUP_SIZE); + const momentPass = encoder.beginComputePass(); + momentPass.setPipeline(computeMomentsPipeline); + momentPass.setBindGroup(0, buildHistogramBindGroup); + for (let axis = 0; axis < 3; axis++) { + momentPass.setBindGroup(1, computeMomentsAxisBindGroups[axis]); + momentPass.dispatchWorkgroups(workGroupsPerDim, workGroupsPerDim); + } + momentPass.end(); + + encoder.copyBufferToBuffer( + momentsRBuffer, 0, + momentsBuffer, 0, + TOTAL_SIZE * Uint32Array.BYTES_PER_ELEMENT + ); + encoder.copyBufferToBuffer( + momentsGBuffer, 0, + momentsBuffer, TOTAL_SIZE * Uint32Array.BYTES_PER_ELEMENT, + TOTAL_SIZE * Uint32Array.BYTES_PER_ELEMENT + 
); + encoder.copyBufferToBuffer( + momentsBBuffer, 0, + momentsBuffer, 2 * TOTAL_SIZE * Uint32Array.BYTES_PER_ELEMENT, + TOTAL_SIZE * Uint32Array.BYTES_PER_ELEMENT + ); + encoder.copyBufferToBuffer( + weightsBuffer, 0, + momentsBuffer, 3 * TOTAL_SIZE * Uint32Array.BYTES_PER_ELEMENT, + TOTAL_SIZE * Uint32Array.BYTES_PER_ELEMENT + ); + encoder.copyBufferToBuffer( + mBuffer, 0, + momentsBuffer, 4 * TOTAL_SIZE * Uint32Array.BYTES_PER_ELEMENT, + TOTAL_SIZE * Uint32Array.BYTES_PER_ELEMENT + ); + device.queue.submit([encoder.finish()]); + + for (let i = 1; i < K; i++) { + encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(createBoxPipeline); + device.queue.writeBuffer(totalCubesNumUniformBuffer, 0, new Uint32Array([i])); + + pass.setBindGroup(0, momentsBindGroup); + pass.setBindGroup(1, cubesBindGroup); + pass.dispatchWorkgroups(1); + pass.end(); + device.queue.submit([encoder.finish()]); + } + + encoder = device.createCommandEncoder(); + const pass = encoder.beginComputePass(); + pass.setPipeline(createResultPipeline); + pass.setBindGroup(0, momentsBindGroup); + pass.setBindGroup(1, cubesResultBindGroup); + pass.setBindGroup(2, resultsBindGroup); + pass.dispatchWorkgroups(1); + pass.end(); + device.queue.submit([encoder.finish()]); + await device.queue.onSubmittedWorkDone(); + + return resultsBuffer; +} diff --git a/typescript/quantize-webgpu/wu/pipelines/buildHistogram.ts b/typescript/quantize-webgpu/wu/pipelines/buildHistogram.ts new file mode 100644 index 00000000..2ff0a282 --- /dev/null +++ b/typescript/quantize-webgpu/wu/pipelines/buildHistogram.ts @@ -0,0 +1,115 @@ +interface BuildHistogramResult { + weightsBuffer: GPUBuffer; + momentsRBuffer: GPUBuffer; + momentsGBuffer: GPUBuffer; + momentsBBuffer: GPUBuffer; + momentsBuffer: GPUBuffer; + buildHistogramPipeline: GPUComputePipeline; + inputBindGroup: GPUBindGroup; + buildHistogramBindGroup: GPUBindGroup; + buildHistogramBindGroupLayout: GPUBindGroupLayout; +} + 
+export async function setupBuildHistogram(device: GPUDevice, texture: GPUTexture): Promise { + const histogramSize = 35937; + const weightsBuffer = device.createBuffer({ + label: 'weights', + size: histogramSize * 4, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC + }); + const momentsRBuffer = device.createBuffer({ + label: 'moments_r', + size: histogramSize * 4, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC + }); + const momentsBBuffer = device.createBuffer({ + label: 'moments_b', + size: histogramSize * 4, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC + }); + const momentsGBuffer = device.createBuffer({ + label: 'moments_g', + size: histogramSize * 4, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC + }); + const momentsBuffer = device.createBuffer({ + label: 'moments', + size: histogramSize * 4, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC + }); + + const inputBindGroupLayout = device.createBindGroupLayout({ + entries: [{ + binding: 0, + visibility: GPUShaderStage.COMPUTE, + texture: { sampleType: 'float', viewDimension: '2d' } + }] + }); + + const buildHistogramBindGroupLayout = device.createBindGroupLayout({ + entries: [{ + binding: 0, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }, + { + binding: 1, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }, { + binding: 2, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }, { + binding: 3, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }, { + binding: 4, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }] + }); + + const inputBindGroup = device.createBindGroup({ + layout: inputBindGroupLayout, + entries: [ + { binding: 0, resource: texture.createView() } + ] + }); + + const buildHistogramBindGroup = device.createBindGroup({ + layout: buildHistogramBindGroupLayout, + entries: [ + { binding: 0, resource: { buffer: weightsBuffer } }, + { binding: 1, 
resource: { buffer: momentsRBuffer } }, + { binding: 2, resource: { buffer: momentsGBuffer } }, + { binding: 3, resource: { buffer: momentsBBuffer } }, + { binding: 4, resource: { buffer: momentsBuffer } }, + ] + }); + const buildHistogramPipelineLayout = device.createPipelineLayout({ + bindGroupLayouts: [inputBindGroupLayout, buildHistogramBindGroupLayout] + }); + + const buildHistogramModule = device.createShaderModule({ + code: await fetch(new URL('../shaders/build_histogram.wgsl', import.meta.url).toString()).then(res => res.text()) + }); + + const buildHistogramPipeline = device.createComputePipeline({ + layout: buildHistogramPipelineLayout, + compute: { module: buildHistogramModule } + }); + + return { + weightsBuffer, + momentsRBuffer, + momentsGBuffer, + momentsBBuffer, + momentsBuffer, + buildHistogramPipeline, + inputBindGroup, + buildHistogramBindGroup, + buildHistogramBindGroupLayout + }; +} \ No newline at end of file diff --git a/typescript/quantize-webgpu/wu/pipelines/computeMoments.ts b/typescript/quantize-webgpu/wu/pipelines/computeMoments.ts new file mode 100644 index 00000000..7123fdf6 --- /dev/null +++ b/typescript/quantize-webgpu/wu/pipelines/computeMoments.ts @@ -0,0 +1,52 @@ +interface ComputeMomentsResult { + computeMomentsAxisBindGroups: GPUBindGroup[]; + computeMomentsPipeline: GPUComputePipeline; +} + +export async function setupComputeMoments( + device: GPUDevice, + momentsBindGroupLayout: GPUBindGroupLayout +): Promise { + const computeMomentsAxisBindGroupLayout = device.createBindGroupLayout({ + entries: [{ + binding: 0, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'uniform' } + }] + }); + + const computeMomentsAxisBindGroups = []; + for (let axis = 0; axis < 3; axis++) { + const axisUniformBuffer = device.createBuffer({ + size: Uint32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.UNIFORM, + mappedAtCreation: true + }); + new Uint32Array(axisUniformBuffer.getMappedRange()).set([axis]); + axisUniformBuffer.unmap(); + + 
const bindGroup = device.createBindGroup({ + layout: computeMomentsAxisBindGroupLayout, + entries: [ + { binding: 0, resource: { buffer: axisUniformBuffer } } + ] + }); + computeMomentsAxisBindGroups.push(bindGroup); + } + + const computeMomentsModule = device.createShaderModule({ + code: await fetch(new URL('../shaders/compute_moments.wgsl', import.meta.url).toString()).then(res => res.text()) + }); + const computeMomentsPipelineLayout = device.createPipelineLayout({ + bindGroupLayouts: [momentsBindGroupLayout, computeMomentsAxisBindGroupLayout] + }); + const computeMomentsPipeline = device.createComputePipeline({ + layout: computeMomentsPipelineLayout, + compute: { module: computeMomentsModule } + }); + + return { + computeMomentsAxisBindGroups, + computeMomentsPipeline + }; +} \ No newline at end of file diff --git a/typescript/quantize-webgpu/wu/pipelines/createBox.ts b/typescript/quantize-webgpu/wu/pipelines/createBox.ts new file mode 100644 index 00000000..21b3d93f --- /dev/null +++ b/typescript/quantize-webgpu/wu/pipelines/createBox.ts @@ -0,0 +1,100 @@ +interface CreateBoxResult { + momentsBuffer: GPUBuffer; + momentsBindGroup: GPUBindGroup; + cubesBuffer: GPUBuffer; + totalCubesNumUniformBuffer: GPUBuffer; + momentsBindGroupLayout: GPUBindGroupLayout; + cubesBindGroup: GPUBindGroup; + createBoxPipeline: GPUComputePipeline; +} + +export async function setupCreateBox(device: GPUDevice, K: number): Promise { + const SIDE_LENGTH = 33; + const TOTAL_SIZE = 35937; + + const momentsBuffer = device.createBuffer({ + size: 5 * TOTAL_SIZE * Uint32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST + }); + const momentsBindGroupLayout = device.createBindGroupLayout({ + entries: [{ + binding: 0, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'read-only-storage' } + }] + }); + const momentsBindGroup = device.createBindGroup({ + layout: momentsBindGroupLayout, + entries: [ + { binding: 0, resource: { buffer: momentsBuffer } } + ] + 
}); + + const cubesBuffer = device.createBuffer({ + size: 6 * K * Uint32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST + }); + device.queue.writeBuffer(cubesBuffer, 0, new Uint32Array([0, SIDE_LENGTH - 1, 0, SIDE_LENGTH - 1, 0, SIDE_LENGTH - 1])); + + const variancesBuffer = device.createBuffer({ + size: K * Float32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.STORAGE + }); + const currentCubeIdxBuffer = device.createBuffer({ + size: Uint32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.STORAGE + }); + const totalCubesNumUniformBuffer = device.createBuffer({ + size: Uint32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST + }); + const cubesBindGroupLayout = device.createBindGroupLayout({ + entries: [{ + binding: 0, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }, { + binding: 1, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }, { + binding: 2, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }, { + binding: 3, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'uniform' } + }] + }); + const cubesBindGroup = device.createBindGroup({ + layout: cubesBindGroupLayout, + entries: [ + { binding: 0, resource: { buffer: cubesBuffer } }, + { binding: 1, resource: { buffer: variancesBuffer } }, + { binding: 2, resource: { buffer: currentCubeIdxBuffer } }, + { binding: 3, resource: { buffer: totalCubesNumUniformBuffer } } + ] + }); + + const createBoxModule = device.createShaderModule({ + code: await fetch(new URL('../shaders/create_box.wgsl', import.meta.url).toString()).then(res => res.text()) + }); + const createBoxPipelineLayout = device.createPipelineLayout({ + bindGroupLayouts: [momentsBindGroupLayout, cubesBindGroupLayout] + }); + const createBoxPipeline = device.createComputePipeline({ + layout: createBoxPipelineLayout, + compute: { module: createBoxModule } + }); + + return { + momentsBuffer, + momentsBindGroup, + 
cubesBuffer, + totalCubesNumUniformBuffer, + momentsBindGroupLayout, + cubesBindGroup, + createBoxPipeline + }; +} \ No newline at end of file diff --git a/typescript/quantize-webgpu/wu/pipelines/createResult.ts b/typescript/quantize-webgpu/wu/pipelines/createResult.ts new file mode 100644 index 00000000..0093b9d3 --- /dev/null +++ b/typescript/quantize-webgpu/wu/pipelines/createResult.ts @@ -0,0 +1,69 @@ +interface CreateResultResult { + resultsBuffer: GPUBuffer; + cubesResultBindGroup: GPUBindGroup; + resultsBindGroup: GPUBindGroup; + createResultPipeline: GPUComputePipeline; +} + +export async function setupCreateResult( + device: GPUDevice, + K: number, + momentsBindGroupLayout: GPUBindGroupLayout, + cubesBuffer: GPUBuffer, + totalCubesNumUniformBuffer: GPUBuffer +): Promise { + const cubesResultBindGroupLayout = device.createBindGroupLayout({ + entries: [{ + binding: 0, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'read-only-storage' } + }, { + binding: 1, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'uniform' } + }] + }); + const cubesResultBindGroup = device.createBindGroup({ + layout: cubesResultBindGroupLayout, + entries: [ + { binding: 0, resource: { buffer: cubesBuffer } }, + { binding: 1, resource: { buffer: totalCubesNumUniformBuffer } } + ] + }); + + const resultsBuffer = device.createBuffer({ + size: 3 * K * Float32Array.BYTES_PER_ELEMENT, + usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC + }); + const resultsBindGroupLayout = device.createBindGroupLayout({ + entries: [{ + binding: 0, + visibility: GPUShaderStage.COMPUTE, + buffer: { type: 'storage' } + }] + }); + const resultsBindGroup = device.createBindGroup({ + layout: resultsBindGroupLayout, + entries: [ + { binding: 0, resource: { buffer: resultsBuffer } } + ] + }); + + const createResultModule = device.createShaderModule({ + code: await fetch(new URL('../shaders/create_result.wgsl', import.meta.url).toString()).then(res => res.text()) + }); + const 
createResultPipelineLayout = device.createPipelineLayout({ + bindGroupLayouts: [momentsBindGroupLayout, cubesResultBindGroupLayout, resultsBindGroupLayout] + }); + const createResultPipeline = device.createComputePipeline({ + layout: createResultPipelineLayout, + compute: { module: createResultModule } + }); + + return { + resultsBuffer, + cubesResultBindGroup, + resultsBindGroup, + createResultPipeline + }; +} \ No newline at end of file diff --git a/typescript/quantize-webgpu/wu/shaders/build_histogram.wgsl b/typescript/quantize-webgpu/wu/shaders/build_histogram.wgsl new file mode 100644 index 00000000..0b544689 --- /dev/null +++ b/typescript/quantize-webgpu/wu/shaders/build_histogram.wgsl @@ -0,0 +1,43 @@ +@group(0) @binding(0) var tex: texture_2d; +@group(1) @binding(0) var weights: array>; +@group(1) @binding(1) var moments_r: array>; +@group(1) @binding(2) var moments_g: array>; +@group(1) @binding(3) var moments_b: array>; +@group(1) @binding(4) var moments: array>; + +const INDEX_BITS = 5u; + +fn get_index(r: u32, g: u32, b: u32) -> u32 { + return (r << (2 * INDEX_BITS)) + (r << (INDEX_BITS + 1)) + r + (g << INDEX_BITS) + g + b; +} + +@compute @workgroup_size(256) +fn cs(@builtin(global_invocation_id) id: vec3u) { + let dimensions = textureDimensions(tex); + let width = u32(dimensions.x); + let height = u32(dimensions.y); + + let pointId = id.x + id.y * width; + + if (pointId >= width * height) { + return; + } + + let pixel = textureLoad(tex, id.xy, 0); + + let r = u32(pixel.r * 255.0); + let g = u32(pixel.g * 255.0); + let b = u32(pixel.b * 255.0); + + let bits_to_remove = 8u - INDEX_BITS; + let ir = (r >> bits_to_remove) + 1u; + let ig = (g >> bits_to_remove) + 1u; + let ib = (b >> bits_to_remove) + 1u; + let index = get_index(ir, ig, ib); + + atomicAdd(&weights[index], 1u); + atomicAdd(&moments_r[index], r); + atomicAdd(&moments_g[index], g); + atomicAdd(&moments_b[index], b); + atomicAdd(&moments[index], r * r + g * g + b * b); +} diff --git 
a/typescript/quantize-webgpu/wu/shaders/compute_moments.wgsl b/typescript/quantize-webgpu/wu/shaders/compute_moments.wgsl
new file mode 100644
index 00000000..63578db7
--- /dev/null
+++ b/typescript/quantize-webgpu/wu/shaders/compute_moments.wgsl
@@ -0,0 +1,60 @@
+@group(0) @binding(0) var weights: array; // NOTE(review): no address space or element type is visible on these declarations in this copy - presumably angle-bracket arguments were lost; confirm against the real shader
+@group(0) @binding(1) var moments_r: array;
+@group(0) @binding(2) var moments_g: array;
+@group(0) @binding(3) var moments_b: array;
+@group(0) @binding(4) var moments: array; // read as integers when axis == 0 and as f32 bit patterns otherwise (see the loop below)
+// Selects which get_index argument the running sum walks: 0 -> first, 1 -> second, anything else -> third.
+@group(1) @binding(0) var axis: u32;
+// INDEX_BITS sets the shifts in get_index; SIDE_LENGTH bounds the guard and the loop in cs.
+const INDEX_BITS = 5u;
+const SIDE_LENGTH = 33u;
+// Flattens (r, g, b) to r * 1089 + g * 33 + b, i.e. an index into a 33 x 33 x 33 grid.
+fn get_index(r: u32, g: u32, b: u32) -> u32 {
+    return (r << (2 * INDEX_BITS)) + (r << (INDEX_BITS + 1)) + r + (g << INDEX_BITS) + g + b;
+}
+// One invocation per (x, y) pair; each replaces the entries along the selected axis with their running (prefix) sums, in place.
+@compute @workgroup_size(16, 16)
+fn cs(@builtin(global_invocation_id) id: vec3u) {
+    let x = id.x + 1u; // coordinates start at 1; index 0 on each axis is never touched here
+    let y = id.y + 1u;
+    // Drop invocations outside 1..32 on either fixed coordinate.
+    if (x >= SIDE_LENGTH || y >= SIDE_LENGTH) {
+        return;
+    }
+    // Running totals carried along the axis.
+    var index = 0u;
+    var sum_weights = 0u;
+    var sum_moments_r = 0u;
+    var sum_moments_g = 0u;
+    var sum_moments_b = 0u;
+    var sum_moments = 0f;
+    for (var i = 1u; i < SIDE_LENGTH; i++) {
+        if (axis == 0u) {
+            index = get_index(i, x, y);
+        } else if (axis == 1u) {
+            index = get_index(x, i, y);
+        } else {
+            index = get_index(x, y, i);
+        }
+        // Accumulate this cell into the running totals.
+        sum_weights += weights[index];
+        sum_moments_r += moments_r[index];
+        sum_moments_g += moments_g[index];
+        sum_moments_b += moments_b[index];
+
+        // To prevent u32 overflow in moments, they are stored as f32 values bitcast to u32.
+        // After the first axis pass, every entry holds a bitcast f32.
+        // Using f32 from the start is not possible because build_histogram fills them with atomic operations.
+        if (axis == 0) {
+            sum_moments += f32(moments[index]);
+        } else {
+            sum_moments += bitcast(moments[index]); // NOTE(review): bitcast target type is not visible in this copy; the comment above implies f32 - confirm
+        }
+        // Write the prefix sums back over the inputs.
+        weights[index] = sum_weights;
+        moments_r[index] = sum_moments_r;
+        moments_g[index] = sum_moments_g;
+        moments_b[index] = sum_moments_b;
+        moments[index] = bitcast(sum_moments); // stores the f32 total reinterpreted, per the comment above
+    }
+}
diff --git
a/typescript/quantize-webgpu/wu/shaders/create_box.wgsl b/typescript/quantize-webgpu/wu/shaders/create_box.wgsl
new file mode 100644
index 00000000..561b1b63
--- /dev/null
+++ b/typescript/quantize-webgpu/wu/shaders/create_box.wgsl
@@ -0,0 +1,284 @@
+const INDEX_BITS = 5u;
+const SIDE_LENGTH = 33u;
+const TOTAL_SIZE = 35937u; // 33 * 33 * 33
+// A box in index space, described by a lower (r0, g0, b0) and an upper (r1, g1, b1) corner.
+struct Box {
+    r0: u32,
+    r1: u32,
+    g0: u32,
+    g1: u32,
+    b0: u32,
+    b1: u32
+}
+// NOTE(review): element types and sizes of these arrays are not visible in this copy (angle-bracket arguments appear stripped); confirm against the real shader.
+struct Moments {
+    r: array,
+    g: array,
+    b: array,
+    w: array,
+    quad: array
+}
+// Scratch shared between invocations across the barriers in cs; presumably workgroup-address-space arrays - the qualifier is not visible here, confirm.
+var cut_variances_r: array;
+var cut_variances_g: array;
+var cut_variances_b: array;
+var best_cut: array; // per-channel index of the best cut, filled in phase 2 of cs
+
+@group(0) @binding(0) var moments: Moments;
+
+@group(1) @binding(0) var cubes: array; // Box entries; cs reads and writes them
+@group(1) @binding(1) var variances: array; // one variance per cube, written by cs
+@group(1) @binding(2) var current_cube_idx: u32; // cube to split; overwritten at the end of cs
+@group(1) @binding(3) var total_cubes_num: u32; // only read here; used as the slot of the new cube
+// Flattens (r, g, b) to r * 1089 + g * 33 + b.
+fn get_index(r: u32, g: u32, b: u32) -> u32 {
+    return (r << (2 * INDEX_BITS)) + (r << (INDEX_BITS + 1)) + r + (g << INDEX_BITS) + g + b;
+}
+// Eight-corner inclusion-exclusion of `moment` over the box corners, converted to f32.
+fn volume(cube: Box, moment: ptr>) -> f32 {
+    return f32(
+        (*moment)[get_index(cube.r1, cube.g1, cube.b1)] -
+        (*moment)[get_index(cube.r1, cube.g1, cube.b0)] -
+        (*moment)[get_index(cube.r1, cube.g0, cube.b1)] +
+        (*moment)[get_index(cube.r1, cube.g0, cube.b0)] -
+        (*moment)[get_index(cube.r0, cube.g1, cube.b1)] +
+        (*moment)[get_index(cube.r0, cube.g1, cube.b0)] +
+        (*moment)[get_index(cube.r0, cube.g0, cube.b1)] -
+        (*moment)[get_index(cube.r0, cube.g0, cube.b0)]
+    );
+}
+// Returns xx - (dr^2 + dg^2 + db^2) / vol for the box, or 0 when its weight is at most 1.
+fn variance(cube: Box) -> f32 {
+    let vol = volume(cube, &moments.w);
+    if (vol <= 1f) {
+        return 0f;
+    }
+    let dr = volume(cube, &moments.r);
+    let dg = volume(cube, &moments.g);
+    let db = volume(cube, &moments.b);
+    let xx = moments.quad[get_index(cube.r1, cube.g1, cube.b1)] - // same eight-corner combination as volume(), applied to quad without the f32() conversion
+        moments.quad[get_index(cube.r1, cube.g1, cube.b0)] -
+        moments.quad[get_index(cube.r1, cube.g0, cube.b1)] +
+        moments.quad[get_index(cube.r1, cube.g0, cube.b0)] -
+        moments.quad[get_index(cube.r0, cube.g1, cube.b1)] +
+        moments.quad[get_index(cube.r0, cube.g1, cube.b0)] +
+        moments.quad[get_index(cube.r0, cube.g0, cube.b1)] -
+        moments.quad[get_index(cube.r0, cube.g0, cube.b0)];
+    let hypotenuse = dr * dr + dg * dg + db * db;
+    return xx - hypotenuse / vol;
+}
+// Four-corner term on the box's lower face for direction `dir` (0 = r0, 1 = g0, 2 = b0); independent of the cut. Returns 0 for any other dir.
+fn bottom(cube: Box, dir: u32, moment: ptr>) -> f32 {
+    if (dir == 0) {
+        return f32(
+            (*moment)[get_index(cube.r0, cube.g1, cube.b0)] -
+            (*moment)[get_index(cube.r0, cube.g1, cube.b1)] +
+            (*moment)[get_index(cube.r0, cube.g0, cube.b1)] -
+            (*moment)[get_index(cube.r0, cube.g0, cube.b0)]
+        );
+    } else if (dir == 1) {
+        return f32(
+            (*moment)[get_index(cube.r1, cube.g0, cube.b0)] -
+            (*moment)[get_index(cube.r1, cube.g0, cube.b1)] +
+            (*moment)[get_index(cube.r0, cube.g0, cube.b1)] -
+            (*moment)[get_index(cube.r0, cube.g0, cube.b0)]
+        );
+    } else if (dir == 2) {
+        return f32(
+            (*moment)[get_index(cube.r1, cube.g0, cube.b0)] -
+            (*moment)[get_index(cube.r1, cube.g1, cube.b0)] +
+            (*moment)[get_index(cube.r0, cube.g1, cube.b0)] -
+            (*moment)[get_index(cube.r0, cube.g0, cube.b0)]
+        );
+    }
+    return 0;
+}
+// Four-corner term on the plane at `cut` along direction `dir`; bottom() + top() gives the totals for the part of the box below the cut. Returns 0 for any other dir.
+fn top(cube: Box, dir: u32, cut: u32, moment: ptr>) -> f32 {
+    if (dir == 0) {
+        return f32(
+            (*moment)[get_index(cut, cube.g1, cube.b1)] -
+            (*moment)[get_index(cut, cube.g1, cube.b0)] -
+            (*moment)[get_index(cut, cube.g0, cube.b1)] +
+            (*moment)[get_index(cut, cube.g0, cube.b0)]
+        );
+    } else if (dir == 1) {
+        return f32(
+            (*moment)[get_index(cube.r1, cut, cube.b1)] -
+            (*moment)[get_index(cube.r1, cut, cube.b0)] -
+            (*moment)[get_index(cube.r0, cut, cube.b1)] +
+            (*moment)[get_index(cube.r0, cut, cube.b0)]
+        );
+    } else if (dir == 2) {
+        return f32(
+            (*moment)[get_index(cube.r1, cube.g1, cut)] -
+            (*moment)[get_index(cube.r1, cube.g0, cut)] -
+            (*moment)[get_index(cube.r0, cube.g1, cut)] +
+            (*moment)[get_index(cube.r0, cube.g0, cut)]
+        );
+    }
+    return 0;
+}
+// Largest entry found by find_max_variance_cut and the index it was found at.
+struct MaxVarianceResult {
+    max_variance: f32,
+    max_variance_idx: u32,
+}
+// Scans indices first..last-1 and returns the largest value with the first index where it occurs; entry `first` is read even when first >= last.
+fn find_max_variance_cut(cuts_variances: ptr>, first: u32, last: u32) -> MaxVarianceResult {
+    var max_variance = (*cuts_variances)[first];
+    var max_variance_idx = first;
+
+    for (var i = first + 1; i < last; i++) {
+        if ((*cuts_variances)[i] > max_variance) {
+            max_variance = (*cuts_variances)[i];
+            max_variance_idx = i;
+        }
+    }
+
+    return MaxVarianceResult(max_variance, max_variance_idx);
+}
+// Splits cubes[current_cube_idx] in three phases separated by workgroup barriers: score every cut, pick the best cut per channel, then apply the split.
+@compute @workgroup_size(3, 33)
+fn cs(@builtin(global_invocation_id) id: vec3u) { // NOTE(review): global ids index the shared arrays, so this presumably expects a single-workgroup dispatch - confirm at the call site
+    let channel = id.x;
+    let cut = id.y;
+
+    let cube = cubes[current_cube_idx];
+    var first = 0u;
+    var last = SIDE_LENGTH;
+    if (channel == 0) { // candidate cuts for a channel are lower + 1 .. upper - 1
+        first = cube.r0 + 1;
+        last = cube.r1;
+    } else if (channel == 1) {
+        first = cube.g0 + 1;
+        last = cube.g1;
+    } else if (channel == 2) {
+        first = cube.b0 + 1;
+        last = cube.b1;
+    }
+    // Phase 1: each invocation whose cut lies inside the candidate range scores that cut.
+    if (cut >= first && cut < last && channel < 3) {
+        let whole = vec4f( // (r, g, b, weight) totals for the whole box
+            volume(cube, &moments.r),
+            volume(cube, &moments.g),
+            volume(cube, &moments.b),
+            volume(cube, &moments.w)
+        );
+
+        let bottom = vec4f(
+            bottom(cube, channel, &moments.r),
+            bottom(cube, channel, &moments.g),
+            bottom(cube, channel, &moments.b),
+            bottom(cube, channel, &moments.w)
+        );
+
+        let top = vec4f(
+            top(cube, channel, cut, &moments.r),
+            top(cube, channel, cut, &moments.g),
+            top(cube, channel, cut, &moments.b),
+            top(cube, channel, cut, &moments.w)
+        );
+
+        var half = bottom + top; // totals for the part below the cut
+        // Score = |sum|^2 / weight of the lower part plus the same for the remainder; 0 when either part has no weight.
+        var variance_sum = 0f;
+        if (half[3] > 0) {
+            variance_sum = (half[0] * half[0] + half[1] * half[1] + half[2] * half[2]) / half[3];
+
+            half = whole - half; // switch to the part above the cut
+
+            if (half[3] > 0) {
+                variance_sum += (half[0] * half[0] + half[1] * half[1] + half[2] * half[2]) / half[3];
+            } else {
+                variance_sum = 0f;
+            }
+        }
+        if (channel == 0) {
+            cut_variances_r[cut] = variance_sum;
+        } else if (channel == 1) {
+            cut_variances_g[cut] = variance_sum;
+        } else if (channel == 2) {
+            cut_variances_b[cut] = variance_sum;
+        }
+    }
+
+    workgroupBarrier();
+    // Phase 2: the cut == 0 invocation of each channel reduces that channel's scores.
+    if (cut == 0) {
+        var result = MaxVarianceResult(0f, 0u);
+
+        if (channel == 0) {
+            result = find_max_variance_cut(&cut_variances_r, first, last);
+        } else if (channel == 1) {
+            result = find_max_variance_cut(&cut_variances_g, first, last);
+        } else {
+            result = find_max_variance_cut(&cut_variances_b, first, last);
+        }
+
+        best_cut[channel] = result.max_variance_idx;
+        // Slot 0 of each score array is reused to hand the channel's best score to phase 3.
+        if (channel == 0) {
+            cut_variances_r[0] = result.max_variance;
+        } else if (channel == 1) {
+            cut_variances_g[0] = result.max_variance;
+        } else {
+            cut_variances_b[0] = result.max_variance;
+        }
+    }
+
+    workgroupBarrier();
+    // Phase 3: a single invocation applies the split and selects the next cube.
+    if (cut == 0 && channel == 0) {
+        let best_variance_r = cut_variances_r[0];
+        let best_variance_g = cut_variances_g[0];
+        let best_variance_b = cut_variances_b[0];
+
+        var direction = 0u;
+        if(best_variance_r > best_variance_g && best_variance_r > best_variance_b) { // strict comparisons: ties fall through to direction 2
+            direction = 0;
+        } else if (best_variance_g > best_variance_r && best_variance_g > best_variance_b) {
+            direction = 1;
+        } else {
+            direction = 2;
+        }
+
+        let chosen_cut = best_cut[direction];
+        var new_cube = cubes[total_cubes_num]; // the new box takes the upper part: it keeps the old upper corner
+        new_cube.r1 = cubes[current_cube_idx].r1;
+        new_cube.g1 = cubes[current_cube_idx].g1;
+        new_cube.b1 = cubes[current_cube_idx].b1;
+        if (direction == 0) {
+            cubes[current_cube_idx].r1 = chosen_cut; // the current box shrinks to the lower part
+            new_cube.r0 = chosen_cut;
+            new_cube.g0 = cube.g0;
+            new_cube.b0 = cube.b0;
+        } else if (direction == 1) {
+            cubes[current_cube_idx].g1 = chosen_cut;
+            new_cube.r0 = cube.r0;
+            new_cube.g0 = chosen_cut;
+            new_cube.b0 = cube.b0;
+        } else {
+            cubes[current_cube_idx].b1 = chosen_cut;
+            new_cube.r0 = cube.r0;
+            new_cube.g0 = cube.g0;
+            new_cube.b0 = chosen_cut;
+        }
+
+        cubes[total_cubes_num] = new_cube;
+        // Refresh the variances of the two boxes that changed.
+        variances[current_cube_idx] = variance(cubes[current_cube_idx]);
+        variances[total_cubes_num] = variance(new_cube);
+        // The next cube to split is the one with the largest variance among slots 0..total_cubes_num (first wins on ties).
+        var next_idx = 0u;
+        var next_variance = variances[0];
+        for (var i = 0u; i <= total_cubes_num; i++) {
+            if (variances[i] > next_variance) {
+                next_variance = variances[i];
+                next_idx = i;
+            }
+        }
+
+        current_cube_idx = next_idx;
+    }
+}
diff --git
a/typescript/quantize-webgpu/wu/shaders/create_result.wgsl b/typescript/quantize-webgpu/wu/shaders/create_result.wgsl
new file mode 100644
index 00000000..ab8ab933
--- /dev/null
+++ b/typescript/quantize-webgpu/wu/shaders/create_result.wgsl
@@ -0,0 +1,72 @@
+const INDEX_BITS = 5u;
+const SIDE_LENGTH = 33u;
+const TOTAL_SIZE = 35937u; // 33 * 33 * 33
+// A box in index space, described by a lower (r0, g0, b0) and an upper (r1, g1, b1) corner.
+struct Box {
+    r0: u32,
+    r1: u32,
+    g0: u32,
+    g1: u32,
+    b0: u32,
+    b1: u32
+}
+// NOTE(review): element types and sizes of these arrays are not visible in this copy (angle-bracket arguments appear stripped); confirm against the real shader.
+struct Moments {
+    r: array,
+    g: array,
+    b: array,
+    w: array,
+    quad: array
+}
+
+@group(0) @binding(0) var moments: Moments;
+
+@group(1) @binding(0) var cubes: array; // Box entries, only read here
+@group(1) @binding(1) var total_cubes_num: u32; // upper bound (inclusive) on the cube indices processed below
+
+@group(2) @binding(0) var results: array; // three values per cube, in r, g, b order
+// Flattens (r, g, b) to r * 1089 + g * 33 + b.
+fn get_index(r: u32, g: u32, b: u32) -> u32 {
+    return (r << (2 * INDEX_BITS)) + (r << (INDEX_BITS + 1)) + r + (g << INDEX_BITS) + g + b;
+}
+// Eight-corner inclusion-exclusion of `moment` over the box corners, converted to f32.
+fn volume(cube: Box, moment: ptr>) -> f32 {
+    return f32(
+        (*moment)[get_index(cube.r1, cube.g1, cube.b1)] -
+        (*moment)[get_index(cube.r1, cube.g1, cube.b0)] -
+        (*moment)[get_index(cube.r1, cube.g0, cube.b1)] +
+        (*moment)[get_index(cube.r1, cube.g0, cube.b0)] -
+        (*moment)[get_index(cube.r0, cube.g1, cube.b1)] +
+        (*moment)[get_index(cube.r0, cube.g1, cube.b0)] +
+        (*moment)[get_index(cube.r0, cube.g0, cube.b1)] -
+        (*moment)[get_index(cube.r0, cube.g0, cube.b0)]
+    );
+}
+// One invocation per (channel, cube): writes that cube's mean channel value scaled to 0..1, or -1.0 when the cube has no weight.
+@compute @workgroup_size(3, 32)
+fn cs(@builtin(global_invocation_id) id: vec3u) {
+    let channel = id.x;
+    let cube_idx = id.y;
+    // NOTE(review): `>` lets cube_idx == total_cubes_num through; whether that slot holds a valid cube depends on how the caller maintains the counter - confirm.
+    if (cube_idx > total_cubes_num) {
+        return;
+    }
+
+    let cube = cubes[cube_idx];
+    let weight = volume(cube, &moments.w);
+
+    if (weight > 0) {
+        if (channel == 0) {
+            let r = volume(cube, &moments.r) / weight; // channel sum over the box divided by its weight
+            results[cube_idx * 3 + 0] = r / 255.0;
+        } else if (channel == 1) {
+            let g = volume(cube, &moments.g) / weight;
+            results[cube_idx * 3 + 1] = g / 255.0;
+        } else {
+            let b = volume(cube, &moments.b) / weight;
+            results[cube_idx * 3 + 2] = b / 255.0;
+        }
+    } else {
+        results[cube_idx * 3 + channel] = -1.0; // placeholder for a cube without weight
+    }
+}
diff --git
a/typescript/spec/support/jasmine-browser.mjs b/typescript/spec/support/jasmine-browser.mjs
new file mode 100644
index 00000000..0f5f621e
--- /dev/null
+++ b/typescript/spec/support/jasmine-browser.mjs
@@ -0,0 +1,32 @@
+// jasmine-browser-runner configuration: serves the compiled WebGPU quantizer
+// specs (quantize-webgpu/**/*_test.js) as ES modules in Chrome.
+export default {
+  srcDir: "src",
+  srcFiles: [],
+  specDir: ".",
+  specFiles: [
+    "quantize-webgpu/**/*_test.js"
+  ],
+  helpers: [
+  ],
+  esmFilenameExtension: ".js",
+  enableTopLevelAwait: false,
+  env: {
+    stopSpecOnExpectationFailure: false,
+    stopOnSpecFailure: false,
+    random: true
+  },
+  listenAddress: "localhost",
+
+  hostname: "localhost",
+
+  browser: {
+    name: "chrome",
+    flags: [
+      "--enable-unsafe-webgpu",
+      "--enable-features=Vulkan,UseSkiaRenderer",
+      "--enable-dawn-features=allow_unsafe_apis"
+    ]
+  },
+  moduleType: "module"
+};
diff --git a/typescript/tsconfig.json b/typescript/tsconfig.json
index f603e467..69a6d1e0 100644
--- a/typescript/tsconfig.json
+++ b/typescript/tsconfig.json
@@ -8,7 +8,7 @@
     "emitDecoratorMetadata": true,
     "experimentalDecorators": true,
     "importHelpers": true,
-    "module": "es2015",
+    "module": "es2020",
     "moduleResolution": "node",
     "noFallthroughCasesInSwitch": true,
     "noImplicitAny": true,
@@ -23,13 +23,16 @@
     "strictNullChecks": false,
     "target": "es2020",
     "types": [
-      "jasmine"
-    ]
+      "jasmine",
+      "@webgpu/types"
+    ],
+    "skipLibCheck": true
   },
   "include": [
     "**/*.ts"
   ],
   "exclude": [
-    "**/*_test.ts"
+    // "**/*_test.ts",
+    "node_modules"
   ]
 }
\ No newline at end of file
diff --git a/typescript/utils/color_utils.ts b/typescript/utils/color_utils.ts
index c3db1860..a374c238 100644
--- a/typescript/utils/color_utils.ts
+++ b/typescript/utils/color_utils.ts
@@ -297,3 +297,51 @@ function labInvf(ft: number): number {
     return (116 * ft - 16) / kappa;
   }
 }
+
+/**
+ * Converts packed RGB float triples into `#rrggbb` strings.
+ *
+ * Each channel is scaled by 255, rounded and clamped to 0..255, so
+ * out-of-range or NaN inputs still produce a well-formed color (callers that
+ * use negative placeholder values must filter them out beforehand). A
+ * trailing incomplete triple is ignored.
+ *
+ * @param colors Flat array laid out as r, g, b, r, g, b, ... with channels
+ *     nominally in [0, 1].
+ * @return One hex string per complete triple.
+ */
+export function floatArrayToHex(colors: Float32Array): string[] {
+  const toHexByte = (channel: number): string => {
+    const scaled = Math.round(channel * 255);
+    // Math.min/Math.max propagate NaN, so map it to 0 explicitly.
+    const clamped = Number.isNaN(scaled) ? 0 : Math.min(255, Math.max(0, scaled));
+    return clamped.toString(16).padStart(2, '0');
+  };
+  const hexColors: string[] = [];
+  for (let i = 0; i + 2 < colors.length; i += 3) {
+    const r = toHexByte(colors[i]);
+    const g = toHexByte(colors[i + 1]);
+    const b = toHexByte(colors[i + 2]);
+    hexColors.push(`#${r}${g}${b}`);
+  }
+  return hexColors;
+}
+
+/**
+ * Unpacks ARGB integers into a flat byte array, four bytes per pixel in
+ * r, g, b, a order.
+ *
+ * @param pixels Colors, each read through the `*FromArgb` helpers.
+ * @return A `Uint8Array` of length `pixels.length * 4`.
+ */
+export function pixelsToTextureData(pixels: number[]): Uint8Array {
+  const textureData = new Uint8Array(pixels.length * 4);
+  for (let i = 0; i < pixels.length; i++) {
+    const pixel = pixels[i];
+    textureData[i * 4] = redFromArgb(pixel);
+    textureData[i * 4 + 1] = greenFromArgb(pixel);
+    textureData[i * 4 + 2] = blueFromArgb(pixel);
+    textureData[i * 4 + 3] = alphaFromArgb(pixel);
+  }
+  return textureData;
+}