Skip to content

Commit f6714a6

Browse files
authored
Add TFLite + Transformer example (#179)
* Add TFLite + Transformer example * Add .tflite to .gitignore * Disable radio buttons when export starts, StyleTransferShaderProgram: Remove unused code, added TAG constant, and added TODO for migrating to Kotlin
1 parent c1fd233 commit f6714a6

17 files changed

+680
-5
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,6 @@ docs-gen
5151
site
5252
*.bak
5353
.idea/appInsightsSettings.xml
54+
55+
#TFLite
56+
*.tflite

build.gradle.kts

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
buildscript {
17+
dependencies {
18+
classpath("de.undercouch:gradle-download-task:4.1.2")
19+
}
20+
}
21+
1622
@Suppress("DSL_SCOPE_VIOLATION")
1723
plugins {
1824
alias(libs.plugins.android.application) apply false
@@ -38,7 +44,7 @@ versionCatalogUpdate {
3844
affectedModuleDetector {
3945
baseDir = "${project.rootDir}"
4046
pathsAffectingAllModules = setOf(
41-
"gradle/libs.versions.toml",
47+
"gradle/libs.versions.toml",
4248
)
4349
excludedModules = setOf<String>()
4450

gradle/libs.versions.toml

+8
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ material = "1.12.0-beta01"
4444
constraintlayout = "2.1.4"
4545
glide-compose = "1.0.0-beta01"
4646
glance = "1.1.0-SNAPSHOT"
47+
tensorflowLite = "2.9.0"
48+
tensorflowLiteGpuDelegatePlugin = "0.4.4"
49+
tensorflowLiteSupport = "0.4.2"
4750

4851
[libraries]
4952

@@ -155,6 +158,11 @@ glide-compose = { group = "com.github.bumptech.glide", name = "compose", version
155158
appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" }
156159
material = { group = "com.google.android.material", name = "material", version.ref = "material" }
157160
constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraintlayout" }
161+
tensorflow-lite = { module = "org.tensorflow:tensorflow-lite", version.ref = "tensorflowLite" }
162+
tensorflow-lite-gpu = { module = "org.tensorflow:tensorflow-lite-gpu", version.ref = "tensorflowLite" }
163+
tensorflow-lite-gpu-delegate-plugin = { module = "org.tensorflow:tensorflow-lite-gpu-delegate-plugin", version.ref = "tensorflowLiteGpuDelegatePlugin" }
164+
tensorflow-lite-select-tf-ops = { module = "org.tensorflow:tensorflow-lite-select-tf-ops", version.ref = "tensorflowLite" }
165+
tensorflow-lite-support = { module = "org.tensorflow:tensorflow-lite-support", version.ref = "tensorflowLiteSupport" }
158166

159167
[plugins]
160168
affectedmoduledetector = { id = "com.dropbox.affectedmoduledetector", version = "0.2.0" }

samples/README.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,16 @@ The sample demonstrates the importance of having proper labels for
116116
A sample showcasing how to handle calls with the Jetpack Telecom API
117117
- [TextSpan](user-interface/text/src/main/java/com/example/platform/ui/text/TextSpan.kt):
118118
buildSpannedString is useful for quickly building a rich text.
119+
- [Transformer and TFLite](media/video/src/main/java/com/example/platform/media/video/TransformerTFLite.kt):
120+
This sample demonstrates using Transformer with TFLite by applying a selected art style to a video.
119121
- [UltraHDR Image Capture](camera/camera2/src/main/java/com/example/platform/camera/imagecapture/Camera2UltraHDRCapture.kt):
120122
This sample demonstrates how to capture a 10-bit compressed still image and
121123
- [UltraHDR to HDR Video](media/ultrahdr/src/main/java/com/example/platform/media/ultrahdr/video/UltraHDRToHDRVideo.kt):
122124
This sample demonstrates converting a series of UltraHDR images into a HDR
123125
- [UltraHDR x OpenGLES SurfaceView](graphics/ultrahdr/src/main/java/com/example/platform/graphics/ultrahdr/opengl/UltraHDRWithOpenGL.kt):
124126
This sample demonstrates displaying an UltraHDR image via an OpenGL Pipeline
125127
- [Video Composition using Media3 Transformer](media/video/src/main/java/com/example/platform/media/video/TransformerVideoComposition.kt):
126-
This sample demonstrates concatenation of two video assets and an image using Media3.
128+
This sample demonstrates concatenation of two video assets and an image using the Media3 Transformer library.
127129
- [Visualizing an UltraHDR Gainmap](graphics/ultrahdr/src/main/java/com/example/platform/graphics/ultrahdr/display/VisualizingAnUltraHDRGainmap.kt):
128130
This sample demonstrates visualizing the underlying gainmap of an UltraHDR
129131
- [WindowInsetsAnimation](user-interface/window-insets/src/main/java/com/example/platform/ui/insets/WindowInsetsAnimation.kt):

samples/media/video/build.gradle.kts

+21
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,25 @@
1616

1717
plugins {
1818
id("com.example.platform.sample")
19+
id("de.undercouch.download")
1920
}
2021

2122
android {
2223
namespace = "com.example.platform.media.video"
2324
viewBinding.isEnabled = true
25+
26+
androidResources {
27+
noCompress += "tflite"
28+
}
29+
}
30+
31+
// Import DownloadModels task for TFLite sample
32+
project.ext.set("ASSET_DIR", "$projectDir/src/main/assets")
33+
project.ext.set("TEST_ASSETS_DIR", "$projectDir/src/androidTest/assets")
34+
// Download default models; if you wish to use your own models then
35+
// place them in the "assets" directory and comment out this line.
36+
apply {
37+
from("download_model.gradle")
2438
}
2539

2640
dependencies {
@@ -37,4 +51,11 @@ dependencies {
3751
implementation(libs.androidx.media3.ui)
3852
implementation(libs.androidx.media3.effect)
3953
implementation(libs.material)
54+
55+
// Tensorflow lite dependencies
56+
implementation(libs.tensorflow.lite)
57+
implementation(libs.tensorflow.lite.gpu)
58+
implementation(libs.tensorflow.lite.gpu.delegate.plugin)
59+
implementation(libs.tensorflow.lite.support)
60+
implementation(libs.tensorflow.lite.select.tf.ops)
4061
}
+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/*
2+
* Copyright 2024 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
// Downloads the style-prediction model into the directory named by the
// ASSET_DIR extra property (set by the script that applies this file) and
// stores it as predict_float16.tflite. `overwrite false` is intended to leave
// an already-present file untouched (see the gradle-download-task docs).
tasks.register('downloadModelFile1', Download) {
17+
src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/style_transfer/android/magenta_arbitrary-image-stylization-v1-256_fp16_prediction_1.tflite'
18+
dest project.ext.ASSET_DIR + '/predict_float16.tflite'
19+
overwrite false
20+
}
21+
22+
// Downloads the style-transfer model into the same ASSET_DIR directory and
// stores it as transfer_float16.tflite.
tasks.register('downloadModelFile2', Download) {
23+
src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/task_library/style_transfer/android/magenta_arbitrary-image-stylization-v1-256_fp16_transfer_1.tflite'
24+
dest project.ext.ASSET_DIR + '/transfer_float16.tflite'
25+
overwrite false
26+
}
27+
28+
// Make both download tasks prerequisites of preBuild so the model files are
// fetched before the build proceeds.
preBuild.dependsOn downloadModelFile1, downloadModelFile2
205 KB
Loading
122 KB
Loading
112 KB
Loading
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
/*
2+
* Copyright 2024 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.example.platform.media.video;
18+
19+
import android.content.Context;
20+
import android.graphics.Bitmap;
21+
import android.graphics.BitmapFactory;
22+
import android.opengl.GLES20;
23+
import android.opengl.GLUtils;
24+
import android.util.Log;
25+
26+
import androidx.media3.common.VideoFrameProcessingException;
27+
import androidx.media3.common.util.GlProgram;
28+
import androidx.media3.common.util.GlUtil;
29+
import androidx.media3.common.util.Size;
30+
import androidx.media3.common.util.UnstableApi;
31+
import androidx.media3.effect.BaseGlShaderProgram;
32+
33+
import com.google.common.collect.ImmutableMap;
34+
35+
import org.tensorflow.lite.DataType;
36+
import org.tensorflow.lite.Interpreter;
37+
import org.tensorflow.lite.InterpreterApi;
38+
import org.tensorflow.lite.gpu.CompatibilityList;
39+
import org.tensorflow.lite.gpu.GpuDelegate;
40+
import org.tensorflow.lite.support.common.FileUtil;
41+
import org.tensorflow.lite.support.common.ops.DequantizeOp;
42+
import org.tensorflow.lite.support.common.ops.NormalizeOp;
43+
import org.tensorflow.lite.support.image.ImageProcessor;
44+
import org.tensorflow.lite.support.image.TensorImage;
45+
import org.tensorflow.lite.support.image.ops.ResizeOp;
46+
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
47+
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
48+
49+
import java.io.IOException;
50+
import java.io.InputStream;
51+
import java.nio.ByteBuffer;
52+
53+
import javax.microedition.khronos.opengles.GL10;
54+
55+
// TODO: Migrate this class to Kotlin
56+
@UnstableApi
57+
final class StyleTransferShaderProgram extends BaseGlShaderProgram {
58+
59+
private static final String TAG = "StyleTransferSP";
60+
private static final String VERTEX_SHADER_PATH = "shaders/vertex_shader_transformation_es2.glsl";
61+
private static final String FRAGMENT_SHADER_PATH = "shaders/fragment_shader_copy_es2.glsl";
62+
63+
private final GlProgram glProgram;
64+
private final InterpreterApi transformInterpreter;
65+
private final int inputTransformTargetHeight;
66+
private final int inputTransformTargetWidth;
67+
private final int[] outputTransformShape;
68+
69+
private final TensorBuffer predictOutput;
70+
71+
private int width;
72+
private int height;
73+
74+
public StyleTransferShaderProgram(Context context, String styleAssetFileName)
75+
throws VideoFrameProcessingException {
76+
super(/* useHighPrecisionColorComponents= */ false, /* texturePoolCapacity= */ 1);
77+
78+
try {
79+
glProgram = new GlProgram(context, VERTEX_SHADER_PATH, FRAGMENT_SHADER_PATH);
80+
81+
Interpreter.Options options = new Interpreter.Options();
82+
83+
CompatibilityList compatibilityList = new CompatibilityList();
84+
if (compatibilityList.isDelegateSupportedOnThisDevice()) {
85+
GpuDelegate.Options gpuDelegateOptions = compatibilityList.getBestOptionsForThisDevice();
86+
GpuDelegate gpuDelegate = new GpuDelegate(gpuDelegateOptions);
87+
options.addDelegate(gpuDelegate);
88+
} else {
89+
options.setNumThreads(6);
90+
}
91+
String predictModel = "predict_float16.tflite";
92+
String transferModel = "transfer_float16.tflite";
93+
Interpreter predictInterpeter =
94+
new Interpreter(FileUtil.loadMappedFile(context, predictModel), options);
95+
transformInterpreter =
96+
InterpreterApi.create(FileUtil.loadMappedFile(context, transferModel), options);
97+
int inputPredictTargetHeight = predictInterpeter.getInputTensor(0).shape()[1];
98+
int inputPredictTargetWidth = predictInterpeter.getInputTensor(0).shape()[2];
99+
int[] outputPredictShape = predictInterpeter.getOutputTensor(0).shape();
100+
101+
inputTransformTargetHeight = transformInterpreter.getInputTensor(0).shape()[1];
102+
inputTransformTargetWidth = transformInterpreter.getInputTensor(0).shape()[2];
103+
outputTransformShape = transformInterpreter.getOutputTensor(0).shape();
104+
105+
InputStream inputStream = context.getAssets().open(styleAssetFileName);
106+
Bitmap styleImage = BitmapFactory.decodeStream(inputStream);
107+
inputStream.close();
108+
TensorImage styleTensorImage =
109+
getScaledTensorImage(styleImage, inputPredictTargetWidth, inputPredictTargetHeight);
110+
predictOutput = TensorBuffer.createFixedSize(outputPredictShape, DataType.FLOAT32);
111+
predictInterpeter.run(styleTensorImage.getBuffer(), predictOutput.getBuffer());
112+
} catch (IOException | GlUtil.GlException e) {
113+
Log.w(TAG, "Error setting up TfShaderProgram", e);
114+
throw new VideoFrameProcessingException(e);
115+
}
116+
}
117+
118+
@Override
119+
public Size configure(int inputWidth, int inputHeight) {
120+
width = inputWidth;
121+
height = inputHeight;
122+
return new Size(inputWidth, inputHeight);
123+
}
124+
125+
@Override
126+
public void drawFrame(int inputTexId, long presentationTimeUs)
127+
throws VideoFrameProcessingException {
128+
ByteBuffer pixelBuffer = ByteBuffer.allocateDirect(width * height * 4);
129+
130+
Bitmap bitmap;
131+
int texId;
132+
try {
133+
int[] boundFramebuffer = new int[1];
134+
GLES20.glGetIntegerv(GLES20.GL_FRAMEBUFFER_BINDING, boundFramebuffer, /* offset= */ 0);
135+
136+
int fboId = GlUtil.createFboForTexture(inputTexId);
137+
GlUtil.focusFramebufferUsingCurrentContext(fboId, width, height);
138+
GLES20.glReadPixels(
139+
/* x= */ 0,
140+
/* y= */ 0,
141+
width,
142+
height,
143+
GLES20.GL_RGBA,
144+
GLES20.GL_UNSIGNED_BYTE,
145+
pixelBuffer);
146+
GlUtil.checkGlError();
147+
bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
148+
bitmap.copyPixelsFromBuffer(pixelBuffer);
149+
150+
Log.w(TAG, "Process frame at " + (presentationTimeUs / 1000) + " ms");
151+
long before = System.currentTimeMillis();
152+
TensorImage tensorImage =
153+
getScaledTensorImage(bitmap, inputTransformTargetWidth, inputTransformTargetHeight);
154+
Log.w(TAG, "- Scale " + (System.currentTimeMillis() - before) + " ms");
155+
TensorBuffer outputImage =
156+
TensorBuffer.createFixedSize(outputTransformShape, DataType.FLOAT32);
157+
158+
before = System.currentTimeMillis();
159+
transformInterpreter.runForMultipleInputsOutputs(
160+
new Object[] {tensorImage.getBuffer(), predictOutput.getBuffer()},
161+
ImmutableMap.<Integer, Object>builder().put(0, outputImage.getBuffer()).build());
162+
163+
Log.w(TAG, "- Run " + (System.currentTimeMillis() - before) + " ms");
164+
165+
before = System.currentTimeMillis();
166+
ImageProcessor imagePostProcessor =
167+
new ImageProcessor.Builder()
168+
.add(new DequantizeOp(/* zeroPoint= */ 0f, /* scale= */ 255f))
169+
.build();
170+
TensorImage outputTensorImage = new TensorImage(DataType.FLOAT32);
171+
outputTensorImage.load(outputImage);
172+
Log.w(TAG, "- Load output " + (System.currentTimeMillis() - before) + " ms");
173+
174+
before = System.currentTimeMillis();
175+
Bitmap outputBitmap = imagePostProcessor.process(outputTensorImage).getBitmap();
176+
Log.w(TAG, "- Post process output " + (System.currentTimeMillis() - before) + " ms");
177+
178+
texId =
179+
GlUtil.createTexture(
180+
outputBitmap.getWidth(),
181+
outputBitmap.getHeight(),
182+
/* useHighPrecisionColorComponents= */ false);
183+
GLES20.glBindTexture(GLES20.GL_TEXTURE_2D, texId);
184+
GLES20.glTexParameterf(GL10.GL_TEXTURE_2D, GL10.GL_TEXTURE_MIN_FILTER, GL10.GL_NEAREST);
185+
GLES20.glTexParameterf(GL10.GL_TEXTURE_2D, GL10.GL_TEXTURE_MAG_FILTER, GL10.GL_LINEAR);
186+
GLES20.glTexParameterf(GL10.GL_TEXTURE_2D, GL10.GL_TEXTURE_WRAP_S, GL10.GL_REPEAT);
187+
GLES20.glTexParameterf(GL10.GL_TEXTURE_2D, GL10.GL_TEXTURE_WRAP_T, GL10.GL_REPEAT);
188+
GLUtils.texImage2D(GLES20.GL_TEXTURE_2D, /* level= */ 0, outputBitmap, /* border= */ 0);
189+
GlUtil.checkGlError();
190+
191+
GlUtil.focusFramebufferUsingCurrentContext(boundFramebuffer[0], width, height);
192+
193+
glProgram.use();
194+
glProgram.setSamplerTexIdUniform("uTexSampler", texId, /* texUnitIndex= */ 0);
195+
float[] identityMatrix = GlUtil.create4x4IdentityMatrix();
196+
glProgram.setFloatsUniform("uTexTransformationMatrix", identityMatrix);
197+
glProgram.setFloatsUniform("uTransformationMatrix", identityMatrix);
198+
glProgram.setBufferAttribute(
199+
"aFramePosition",
200+
GlUtil.getNormalizedCoordinateBounds(),
201+
GlUtil.HOMOGENEOUS_COORDINATE_VECTOR_SIZE);
202+
glProgram.bindAttributesAndUniforms();
203+
204+
GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, /* first= */ 0, /* count= */ 4);
205+
GlUtil.checkGlError();
206+
207+
GlUtil.deleteTexture(texId);
208+
} catch (GlUtil.GlException e) {
209+
throw VideoFrameProcessingException.from(e);
210+
}
211+
}
212+
213+
private static TensorImage getScaledTensorImage(
214+
Bitmap bitmap, int targetWidth, int targetHeight) {
215+
int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight());
216+
ImageProcessor imageProcessor =
217+
new ImageProcessor.Builder()
218+
.add(new ResizeWithCropOrPadOp(cropSize, cropSize))
219+
.add(
220+
new ResizeOp(
221+
targetHeight,
222+
targetWidth,
223+
ResizeOp.ResizeMethod.BILINEAR)) // TODO: Not sure why they are swapped?
224+
.add(new NormalizeOp(/* mean= */ 0f, /* stddev= */ 255f))
225+
.build();
226+
TensorImage tensorImage = new TensorImage(DataType.FLOAT32);
227+
tensorImage.load(bitmap);
228+
return imageProcessor.process(tensorImage);
229+
}
230+
}

0 commit comments

Comments
 (0)