Three Beginner Problems with TensorFlow Lite

Preface

I recently started using TensorFlow on a real project. Reading books always felt like staying at the surface, and sure enough, hands-on work brought up problems the books never mentioned.

I've written the problems up here. In hindsight they were all minor, but they still cost real time when I was starting out. Others may well run into them too, so I'm publishing them.

Problem 1: Failing to load the model file

When using TensorFlow Lite on Android, the most common way to ship the model file is to place it in the assets directory, so newcomers frequently hit the following error:

java.lang.RuntimeException: java.lang.reflect.InvocationTargetException
    at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:502)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:930)
 Caused by: java.lang.reflect.InvocationTargetException
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:492)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:930) 
 Caused by: java.io.FileNotFoundException: This file can not be opened as a file descriptor; it is probably compressed
    at android.content.res.AssetManager.nativeOpenAssetFd(Native Method)
    at android.content.res.AssetManager.openFd(AssetManager.java:848)
    at org.tensorflow.lite.support.common.FileUtil.loadMappedFile(FileUtil.java:74)

The error message is quite clear; it even offers a guess: "it is probably compressed".

And indeed it is: by default, AAPT compresses resources under the assets directory (the raw subdirectory excepted).

public static MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath) throws IOException {
    AssetFileDescriptor fileDescriptor = context.getAssets().openFd(filePath);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    //.......
}

In the code above, the openFd call requires that the asset file being opened be uncompressed; its documentation states:

Open an uncompressed asset by mmapping it and returning an AssetFileDescriptor.

To keep the asset from being compressed there is a ready-made fix: add the aaptOptions noCompress property to the Gradle build file:

Extensions of files that will not be stored compressed in the APK. Adding an empty extension, i.e., setting noCompress '', will trivially disable compression for all files.

The change is simple:

android {
    aaptOptions {
        noCompress "tflite"
    }
}
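If changing the build configuration is not an option (for example, the model ships inside a library you don't control), a workaround is to read the asset through a plain InputStream, which works regardless of compression, and copy the bytes into a direct ByteBuffer, which the Interpreter constructor also accepts. A minimal sketch; loadModelBytes is a hypothetical helper and the call sites are assumptions:

```kotlin
import android.content.Context
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Sketch: load a (possibly compressed) asset into a direct ByteBuffer.
// Unlike openFd, AssetManager.open works even for compressed entries.
fun loadModelBytes(context: Context, path: String): ByteBuffer {
    val bytes = context.assets.open(path).use { it.readBytes() }
    return ByteBuffer.allocateDirect(bytes.size)
        .order(ByteOrder.nativeOrder())
        .apply { put(bytes); rewind() }
}
```

The trade-off is that the whole model is copied into memory instead of being memory-mapped, so for large models the noCompress route is still preferable.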


Problem 2: Buffer size mismatch

When working with images, you load the model and then read in an image for inference. At this stage it is easy to get the image dimensions wrong, which triggers a buffer error like this:

java.lang.IllegalArgumentException: Cannot convert between a TensorFlowLite buffer with 602112 bytes and a ByteBuffer with 4915200 bytes.
    at org.tensorflow.lite.Tensor.throwIfShapeIsIncompatible(Tensor.java:272)
    at org.tensorflow.lite.Tensor.throwIfDataIsIncompatible(Tensor.java:249)
    at org.tensorflow.lite.Tensor.setTo(Tensor.java:110)
    at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:145)
    at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:275)
    at org.tensorflow.lite.Interpreter.run(Interpreter.java:249)

One possible cause is that the ResizeOp operator in the ImageProcessor was configured with the wrong width and height.

val imageProcessor: ImageProcessor = ImageProcessor.Builder()
        //......
        .add(ResizeOp(imageSizeY, imageSizeX, ResizeMethod.NEAREST_NEIGHBOR))
        //......
        .build()

To understand ResizeOp, refer to the source code or documentation:

/**
 * Creates a ResizeOp which can resize images to specified size in specified method.
 *
 * @param targetHeight: The expected height of resized image.
 * @param targetWidth: The expected width of resized image.
 * @param resizeMethod: The algorithm to use for resizing. Options: {@link ResizeMethod}
 */
public ResizeOp(int targetHeight, int targetWidth, ResizeMethod resizeMethod)

Here targetHeight and targetWidth are not arbitrary; they must be parsed out of the model file:

val imageShape = tflite.getInputTensor(imageTensorIndex).shape() // {1, height, width, 3}
imageSizeY = imageShape[1]
imageSizeX = imageShape[2]
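A quick way to catch this class of mismatch before run throws is to compare the prepared buffer against the tensor's reported byte size. A sketch, reusing the tflite and inputImage names from the snippets in this post and assuming a float32 image input:

```kotlin
// Sketch: verify the prepared input matches what the model expects.
// Tensor.numBytes() reports the exact byte size of the input tensor,
// e.g. 1 * height * width * 3 * 4 bytes for a float32 image input.
val inputTensor = tflite.getInputTensor(0)
check(inputImage.buffer.remaining() == inputTensor.numBytes()) {
    "Input is ${inputImage.buffer.remaining()} bytes, " +
        "but the model expects ${inputTensor.numBytes()} bytes"
}
```

Failing fast with your own message makes the wrong dimension much easier to spot than decoding byte counts like 602112 from the stack trace.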

Problem 3: Type mismatch

When calling the Interpreter's run method, it is very easy to go wrong if you only skim the documentation, for example:

java.lang.IllegalArgumentException: DataType error: cannot resolve DataType of org.tensorflow.lite.support.tensorbuffer.TensorBufferFloat
    at org.tensorflow.lite.Tensor.dataTypeOf(Tensor.java:199)
    at org.tensorflow.lite.Tensor.throwIfTypeIsIncompatible(Tensor.java:257)
    at org.tensorflow.lite.Tensor.throwIfDataIsIncompatible(Tensor.java:248)
    at org.tensorflow.lite.Tensor.copyTo(Tensor.java:141)
    at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:161)
    at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:275)
    at org.tensorflow.lite.Interpreter.run(Interpreter.java:249)

This error comes from passing the wrong argument type to run:

 /**
   * Runs model inference if the model takes only one input, and provides only one output.
   *
   * <p>Warning: The API is more efficient if a {@link Buffer} (preferably direct, but not required)
   * is used as the input/output data type. Please consider using {@link Buffer} to feed and fetch
   * primitive data for better performance. The following concrete {@link Buffer} types are
   * supported:
   *
   * <ul>
   *   <li>{@link ByteBuffer} - compatible with any underlying primitive Tensor type.
   *   <li>{@link FloatBuffer} - compatible with float Tensors.
   *   <li>{@link IntBuffer} - compatible with int32 Tensors.
   *   <li>{@link LongBuffer} - compatible with int64 Tensors.
   * </ul>
   * 
   * ......
   * @param output a multidimensional array of output data, or a {@link Buffer} of primitive types
   *     including int, float, long, and byte. When a {@link Buffer} is used, the caller must ensure
   *     that it is set the appropriate write position. A null value is allowed only if the caller
   *     is using a {@link Delegate} that allows buffer handle interop, and such a buffer has been
   *     bound to the output {@link Tensor}. See {@link Options#setAllowBufferHandleOutput()}.
   * ......
   */
  public void run(Object input, Object output) {
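Given those constraints, the fix for the snippet above is to hand run the TensorBuffer's backing ByteBuffer rather than the wrapper itself. A sketch, assuming a single float32 output tensor:

```kotlin
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer

// Sketch: allocate the output from the model's own shape, then pass the
// underlying ByteBuffer, which is one of the supported Buffer types.
val outputBuffer = TensorBuffer.createFixedSize(
    tflite.getOutputTensor(0).shape(), DataType.FLOAT32)
tflite.run(inputImage.buffer, outputBuffer.buffer.rewind())
// Results can then be read back via outputBuffer.floatArray.
```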

In my view, run really ought to be overloaded, with separate signatures for the different parameter types; only then would it be truly developer-friendly.

Summary

The three problems above are what I hit on my first outing with TF Lite; writing them up may help others. I will keep sharing the typical problems I encounter as I go deeper.
