Using a TensorFlow Hub model in Flutter (Android)

55 views Asked by At

I want to use the food_V1 model (https://tfhub.dev/google/aiy/vision/classifier/food_V1/1) in Flutter, but the download contains no labels.txt file — only a single .tflite file. I extracted the labels from the model and used them in my project, but the model does not run: the app exits. The code is as follows:

import 'dart:io';

import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite/tflite.dart';

/// Entry point: mounts [MyApp] as the root widget.
void main() {
  runApp(MyApp());
}

/// Root widget: a [MaterialApp] whose single screen shows [HomePage] under a
/// titled app bar.
class MyApp extends StatelessWidget {
  @override
  Widget build(BuildContext context) {
    final appBar = AppBar(
      title: Text('Tensor Flow Example'),
      backgroundColor: Color.fromRGBO(61, 62, 61, 1),
    );
    final body = Center(child: HomePage());
    return MaterialApp(home: Scaffold(appBar: appBar, body: body));
  }
}

/// Screen that lets the user pick a photo and shows the classifier's result.
class HomePage extends StatefulWidget {
  @override
  HomePageState createState() {
    return HomePageState();
  }
}

/// State for [HomePage].
///
/// Picks an image from the camera or gallery, runs the bundled TFLite
/// classifier on it, and shows the top recognition.
class HomePageState extends State<HomePage> {
  /// The image currently shown and classified, or `null` before any pick.
  File? imageURI;

  /// Recognitions from the last classification run.
  ///
  /// This is the value returned by `Tflite.runModelOnImage`, and is `null`
  /// until a run completes. Entries are read as maps with `'label'` and
  /// `'confidence'` keys.
  List<dynamic>? result;

  /// Filesystem path of [imageURI].
  String? path;

  /// Pending or completed model load.
  ///
  /// It is created on the first classification, so the model is loaded once
  /// per state rather than once per image.
  Future<void>? _modelLoaded;

  @override
  void dispose() {
    // Release the native interpreter only if this state ever loaded it.
    if (_modelLoaded != null) Tflite.close();
    super.dispose();
  }

  /// Lets the user take a photo, then classifies it.
  Future<void> getImageFromCamera() =>
      _pickAndClassify(ImageSource.camera, "Error getting image from camera");

  /// Lets the user choose an image from the gallery, then classifies it.
  Future<void> getImageFromGallery() =>
      _pickAndClassify(ImageSource.gallery, "Error getting image from gallery");

  /// Shared picker flow for [getImageFromCamera] and [getImageFromGallery].
  ///
  /// [errorMessage] is printed, followed by the caught error, if picking
  /// fails. Classification only runs when a new image was actually picked.
  /// A cancelled or failed pick returns early: otherwise [classifyImage]
  /// would run with no image selected, or re-classify the previous one.
  Future<void> _pickAndClassify(ImageSource source, String errorMessage) async {
    try {
      final picked = await ImagePicker().pickImage(source: source);
      // `null` means the picker was dismissed without choosing an image.
      if (picked == null || !mounted) return;
      setState(() {
        imageURI = File(picked.path);
        path = picked.path;
      });
    } catch (e) {
      print("$errorMessage: $e");
      return;
    }
    await classifyImage();
  }

  /// Classifies [imageURI] and stores the recognitions in [result].
  ///
  /// Does nothing when no image is selected. The model is loaded on the
  /// first call and reused afterwards.
  ///
  /// NOTE(review): the tflite plugin appears to size its output buffer from
  /// the number of lines in the labels file. A labels file whose line count
  /// differs from the model's output size is a plausible cause of the
  /// reported native crash with food_V1. Confirm against the plugin source.
  Future<void> classifyImage() async {
    final image = imageURI;
    if (image == null) return;

    await (_modelLoaded ??= Tflite.loadModel(
        model: "assets/model_unquant.tflite", labels: "assets/labels.txt"));

    final output = await Tflite.runModelOnImage(
      path: image.path,
      imageMean: 127.5,
      imageStd: 127.5,
      numResults: 2,
      threshold: 0.2,
    );

    // The widget may have been disposed while inference was running.
    if (!mounted) return;
    // An empty list means no class passed the threshold; there is no [0].
    if (output != null && output.isNotEmpty) {
      print(output[0]["confidence"].toStringAsFixed(3));
    }
    setState(() {
      result = output;
    });
  }

  @override
  Widget build(BuildContext context) {
    final selected = imageURI;
    return Scaffold(
      backgroundColor: Color.fromRGBO(48, 48, 48, 1),
      body: Center(
        child: SingleChildScrollView(
          child: Column(
            mainAxisAlignment: MainAxisAlignment.center,
            children: <Widget>[
              selected == null
                  ? Text(
                      'please select image from gallery or camera.',
                      style: TextStyle(color: Colors.white),
                    )
                  : ClipRRect(
                      borderRadius: BorderRadius.circular(8.0),
                      child: Image.file(selected,
                          width: 300, height: 200, fit: BoxFit.cover),
                    ),
              Row(
                mainAxisAlignment: MainAxisAlignment.center,
                children: [
                  Container(
                    margin: EdgeInsets.fromLTRB(0, 30, 0, 20),
                    child: IconButton(
                      onPressed: getImageFromCamera,
                      icon: Icon(Icons.camera, color: Colors.white),
                      iconSize: 50,
                      color: Colors.blue,
                      padding: EdgeInsets.fromLTRB(12, 12, 12, 12),
                    ),
                  ),
                  Container(
                    margin: EdgeInsets.fromLTRB(0, 0, 0, 0),
                    child: IconButton(
                      icon: Icon(Icons.image, color: Colors.white),
                      iconSize: 50,
                      onPressed: getImageFromGallery,
                      color: Colors.blue,
                      padding: EdgeInsets.fromLTRB(12, 12, 12, 12),
                    ),
                  ),
                ],
              ),
              SizedBox(height: 10),
              _buildResult(),
            ],
          ),
        ),
      ),
    );
  }

  /// Builds the result line.
  ///
  /// Shows a placeholder before the first run and a notice when nothing
  /// passed the threshold. Otherwise it shows the top label.
  Widget _buildResult() {
    final recognitions = result;
    if (recognitions == null) {
      return Text('Result', style: TextStyle(color: Colors.white));
    }
    if (recognitions.isEmpty) {
      return Text('No confident match',
          style: TextStyle(color: Colors.white));
    }
    return Text(
      _displayLabel(recognitions[0]['label']),
      style: TextStyle(
          color: Colors.white, fontWeight: FontWeight.bold, fontSize: 30),
    );
  }

  /// Returns [label] as text without a leading class index such as `"0 "`.
  ///
  /// Teachable Machine label files prefix each label with its index. A fixed
  /// `substring(2)` only handles one-digit indices. It also cuts real
  /// characters from label files that have no prefix, such as labels
  /// extracted from another model.
  static String _displayLabel(Object? label) =>
      '${label ?? ''}'.replaceFirst(RegExp(r'^\d+\s+'), '');
}

I downloaded the model, extracted its labels into a separate file, and added both to the assets in the pubspec.yaml file. When I use this model, the app exits; however, if I put a TFLite model exported from Teachable Machine into the project instead, the code runs correctly.

1

There is 1 answer

0
Answered by Chathura Chamikara

It seems you are not resizing the picked image. The model you picked from TF Hub requires a specific input shape, (1, 224, 224, 3). Try something like the following:

// Runs the model on a pre-processed float32 buffer instead of an image path.
// NOTE(review): `image` is assumed to be a package:image `img.Image`. It is
// not declared in this snippet, so confirm at the call site.
// NOTE(review): 224 is the input size this answer states for food_V1. With
// mean = std = 127.5, the helper maps channel values 0..255 to -1..1. Confirm
// both against the model's input tensor.
// NOTE(review): this uses `await`, so it must run inside an async function.
var recognitions = await Tflite.runModelOnBinary(
  binary: imageToByteListFloat32(image, 224, 127.5, 127.5),// required
  numResults: 6,    // defaults to 5
  threshold: 0.05,  // defaults to 0.1
  asynch: true      // defaults to true
);

/// Converts [image] into the raw bytes of a `1 x inputSize x inputSize x 3`
/// float32 tensor.
///
/// Each RGB channel is normalised as `(value - mean) / std`. If [image] is
/// not already `inputSize x inputSize`, it is resized first. Reading pixels
/// straight from a larger image would feed the model only its top-left
/// corner, and a smaller image would be read out of bounds.
///
/// Returns the tensor's bytes in row-major (row, column, channel) order.
///
/// Throws [ArgumentError] if [inputSize] is not positive.
Uint8List imageToByteListFloat32(
    img.Image image, int inputSize, double mean, double std) {
  if (inputSize <= 0) {
    throw ArgumentError.value(inputSize, 'inputSize', 'must be positive');
  }
  final source = (image.width == inputSize && image.height == inputSize)
      ? image
      : img.copyResize(image, width: inputSize, height: inputSize);

  final tensor = Float32List(inputSize * inputSize * 3);
  var index = 0;
  for (var y = 0; y < inputSize; y++) {
    for (var x = 0; x < inputSize; x++) {
      final pixel = source.getPixel(x, y);
      tensor[index++] = (img.getRed(pixel) - mean) / std;
      tensor[index++] = (img.getGreen(pixel) - mean) / std;
      tensor[index++] = (img.getBlue(pixel) - mean) / std;
    }
  }
  return tensor.buffer.asUint8List();
}

The model has 2023 classes (see its label map). I also suggest using the tflite_flutter package instead of flutter_tflite, since it is more actively maintained.