软工划水日报-安卓端侧部署(3) 4/25

今天我们尝试了paddle-lite框架,发现意外可行

那么我们来写一下图像预处理函数:

package com.example.ironfarm;

import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Matrix;
import android.util.Log;

import com.baidu.paddle.lite.MobileConfig;
import com.baidu.paddle.lite.PaddlePredictor;
import com.baidu.paddle.lite.PowerMode;
import com.baidu.paddle.lite.Tensor;

import java.io.File;
import java.io.FileInputStream;
import java.util.Arrays;

public class PaddleLiteClassification {
    private static final String TAG = PaddleLiteClassification.class.getName();

    private PaddlePredictor paddlePredictor;
    private Tensor inputTensor;
    private long[] inputShape = new long[]{1, 3, 224, 224};
    private static float[] scale = new float[]{1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
    private static float[] inputMean = new float[]{0.485f, 0.456f, 0.406f};
    private static float[] inputStd = new float[]{0.229f, 0.224f, 0.225f};
    private static final int NUM_THREADS = 4;

    /**
     * @param modelPath model path
     */
    public PaddleLiteClassification(String modelPath) throws Exception {
        File file = new File(modelPath);
        if (!file.exists()) {
            throw new Exception("model file is not exists!");
        }
        try {
            MobileConfig config = new MobileConfig();
            config.setModelFromFile(modelPath);
            config.setThreads(NUM_THREADS);
            config.setPowerMode(PowerMode.LITE_POWER_HIGH);
            paddlePredictor = PaddlePredictor.createPaddlePredictor(config);

            inputTensor = paddlePredictor.getInput(0);
            inputTensor.resize(inputShape);
        } catch (Exception e) {
            e.printStackTrace();
            throw new Exception("load model fail!");
        }
    }

    public float[] predictImage(String image_path) throws Exception {
        if (!new File(image_path).exists()) {
            throw new Exception("image file is not exists!");
        }
        FileInputStream fis = new FileInputStream(image_path);
        Bitmap bitmap = BitmapFactory.decodeStream(fis);
        float[] result = predictImage(bitmap);
        if (bitmap.isRecycled()) {
            bitmap.recycle();
        }
        return result;
    }

    public float[] predictImage(Bitmap bitmap) throws Exception {
        return predict(bitmap);
    }

    public static int getMaxResult(float[] result) {
        float probability = 0;
        int r = 0;
        for (int i = 0; i < result.length; i++) {
            if (probability < result[i]) {
                probability = result[i];
                r = i;
            }
        }
        return r;
    }

    private static float[] getScaledMatrix(Bitmap bitmap, int desWidth, int desHeight) {
        float[] dataBuf = new float[3 * desWidth * desHeight];
        int rIndex;
        int gIndex;
        int bIndex;
        int[] pixels = new int[desWidth * desHeight];
        Bitmap bm = Bitmap.createScaledBitmap(bitmap, desWidth, desHeight, false);
        bm.getPixels(pixels, 0, desWidth, 0, 0, desWidth, desHeight);
        int j = 0;
        int k = 0;
        for (int i = 0; i < pixels.length; i++) {
            int clr = pixels[i];
            j = i / desHeight;
            k = i % desWidth;
            rIndex = j * desWidth + k;
            gIndex = rIndex + desHeight * desWidth;
            bIndex = gIndex + desHeight * desWidth;
            // 转成RGB通道顺序
            dataBuf[bIndex] = (float) (((clr & 0x00ff0000) >> 16) / 255.0);
            dataBuf[gIndex] = (float) (((clr & 0x0000ff00) >> 8) / 255.0);
            dataBuf[rIndex] = (float) (((clr & 0x000000ff)) / 255.0);
        }
        if (bm.isRecycled()) {
            bm.recycle();
        }
        Log.d("sss",  Arrays.toString(dataBuf));
        return dataBuf;
    }

    private Bitmap getScaleBitmap(Bitmap bitmap) {
        int bmpWidth = bitmap.getWidth();
        int bmpHeight = bitmap.getHeight();
        int size = (int) inputShape[2];
        float scaleWidth = (float) size / bitmap.getWidth();
        float scaleHeight = (float) size / bitmap.getHeight();
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        return Bitmap.createBitmap(bitmap, 0, 0, bmpWidth, bmpHeight, matrix, true);
    }

    private float[] predict(Bitmap bmp) throws Exception {
        Bitmap b = getScaleBitmap(bmp);
        float[] inputData = getScaledMatrix(b, (int) inputShape[2], (int) inputShape[3]);
        b.recycle();
        bmp.recycle();
        inputTensor.setData(inputData);

        try {
            paddlePredictor.run();
        } catch (Exception e) {
            throw new Exception("predict image fail! log:" + e);
        }
        Tensor outputTensor = paddlePredictor.getOutput(0);
        float[] result = outputTensor.getFloatData();
        Log.d(TAG, Arrays.toString(result));
        int l = getMaxResult(result);
        return new float[]{l, result[l]};
//        return result;
    }
}

这个函数的核心就是对拍照/相册的图像进行各种预处理转化成float数组……其实是通过rgb颜色来进行的预测

好,今天就努力到这里啦

原文地址:https://www.cnblogs.com/Sakuraba/p/14910214.html