Skip to content

Commit ea2ae94

Browse files
committed
save work
1 parent 0610987 commit ea2ae94

File tree

14 files changed

+134
-29
lines changed

14 files changed

+134
-29
lines changed

.idea/libraries/Flutter_Plugins.xml

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/misc.xml

+4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/vcs.xml

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/workspace.xml

+26-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

android/src/main/java/io/fynn/torch_mobile/TorchMobilePlugin.java

+52-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package io.fynn.torch_mobile;
22

3+
import android.util.Log;
4+
35
import androidx.annotation.NonNull;
46
import io.flutter.embedding.engine.plugins.FlutterPlugin;
57
import io.flutter.plugin.common.MethodCall;
@@ -8,35 +10,71 @@
810
import io.flutter.plugin.common.MethodChannel.Result;
911
import io.flutter.plugin.common.PluginRegistry.Registrar;
1012

13+
import org.pytorch.IValue;
14+
import org.pytorch.Module;
15+
import org.pytorch.Tensor;
16+
import org.pytorch.torchvision.TensorImageUtils;
17+
18+
import java.util.ArrayList;
19+
import java.util.Arrays;
20+
import java.util.List;
21+
1122
/** TorchMobilePlugin */
1223
public class TorchMobilePlugin implements FlutterPlugin, MethodCallHandler {
24+
25+
ArrayList<Module> modules = new ArrayList<>();
26+
1327
@Override
1428
public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) {
15-
final MethodChannel channel = new MethodChannel(flutterPluginBinding.getFlutterEngine().getDartExecutor(), "torch_mobile");
29+
final MethodChannel channel = new MethodChannel(flutterPluginBinding.getFlutterEngine().getDartExecutor(),
30+
"torch_mobile");
1631
channel.setMethodCallHandler(new TorchMobilePlugin());
1732
}
1833

19-
// This static function is optional and equivalent to onAttachedToEngine. It supports the old
20-
// pre-Flutter-1.12 Android projects. You are encouraged to continue supporting
21-
// plugin registration via this function while apps migrate to use the new Android APIs
22-
// post-flutter-1.12 via https://flutter.dev/go/android-project-migration.
23-
//
24-
// It is encouraged to share logic between onAttachedToEngine and registerWith to keep
25-
// them functionally equivalent. Only one of onAttachedToEngine or registerWith will be called
26-
// depending on the user's project. onAttachedToEngine or registerWith must both be defined
27-
// in the same class.
2834
public static void registerWith(Registrar registrar) {
2935
final MethodChannel channel = new MethodChannel(registrar.messenger(), "torch_mobile");
3036
channel.setMethodCallHandler(new TorchMobilePlugin());
3137
}
3238

3339
@Override
3440
public void onMethodCall(@NonNull MethodCall call, @NonNull Result result) {
35-
if (call.method.equals("loadModel")) {
36-
result.success("test");
37-
} else {
38-
result.notImplemented();
41+
switch (call.method){
42+
case "loadModel":
43+
try {
44+
String absPath = call.argument("absPath");
45+
modules.add(Module.load(absPath));
46+
result.success(modules.size() - 1);
47+
} catch (Exception e) {
48+
String assetPath = call.argument("assetPath");
49+
Log.e("TorchMobile", assetPath + " is not a proper model", e);
50+
}
51+
break;
52+
case "predict":
53+
Module module;
54+
try{
55+
int index = call.argument("index");
56+
module = modules.get(index);
57+
58+
ArrayList<Long> shapeList = call.argument("shape");
59+
Long[] shape = shapeList.toArray(new Long[shapeList.size()]);
60+
Tensor.fromBlob(new int[]{1,2,3,4}, toPrimitives(shape));
61+
62+
}catch(Exception e){
63+
Log.e("TorchMobile", "", e);
64+
}
65+
break;
66+
default:
67+
result.notImplemented();
68+
break;
69+
}
70+
}
71+
72+
public static long[] toPrimitives(Long[] objects){
73+
long[] primitives = new long[objects.length];
74+
for(int i = 0; i < objects.length; i++){
75+
primitives[i] = objects[i];
3976
}
77+
return primitives;
4078
}
4179

4280
@Override

example/android/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ buildscript {
55
}
66

77
dependencies {
8-
classpath 'com.android.tools.build:gradle:3.5.0'
8+
classpath 'com.android.tools.build:gradle:3.6.2'
99
}
1010
}
1111

Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
#Fri Jun 23 08:50:38 CEST 2017
1+
#Sat Apr 11 16:19:17 CEST 2020
22
distributionBase=GRADLE_USER_HOME
33
distributionPath=wrapper/dists
44
zipStoreBase=GRADLE_USER_HOME
55
zipStorePath=wrapper/dists
6-
distributionUrl=https\://services.gradle.org/distributions/gradle-5.6.2-all.zip
6+
distributionUrl=https\://services.gradle.org/distributions/gradle-5.6.4-all.zip

example/assets/model.pt

44.7 MB
Binary file not shown.

example/lib/main.dart

+8-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import 'dart:async';
33

44
import 'package:flutter/services.dart';
55
import 'package:torch_mobile/torch_mobile.dart';
6+
import 'package:torch_mobile/model.dart';
67

78
void main() => runApp(MyApp());
89

@@ -17,9 +18,15 @@ class _MyAppState extends State<MyApp> {
1718
super.initState();
1819
}
1920

21+
Future<Model> loadModel(String path) async {
22+
return await TorchMobile.getModel(path);
23+
}
24+
2025
@override
2126
Widget build(BuildContext context) {
22-
TorchMobile.getModel("");
27+
loadModel("assets/model.pt").then((Model model){
28+
model.getPrediction([1,2,3,4], [1,2,2]);
29+
} );
2330
return MaterialApp(
2431
home: Scaffold(
2532
appBar: AppBar(

example/pubspec.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ dev_dependencies:
2525

2626
# The following section is specific to Flutter.
2727
flutter:
28+
assets:
29+
- assets/
2830

2931
# The following line ensures that the Material Icons font is
3032
# included with your application, so that you can use the icons in

lib/model.dart

+6-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@ import 'package:flutter/services.dart';
33
class Model {
44
static const MethodChannel _channel = const MethodChannel('torch_mobile');
55

6-
static Future<String> getPrediction() async {
7-
final String prediction = await _channel.invokeMethod('predict');
6+
final int _index;
7+
8+
Model(this._index);
9+
10+
Future<String> getPrediction(List input, List shape) async {
11+
final String prediction = await _channel.invokeMethod('predict', {"index": this._index, "input": input, "shape": shape});
812
return prediction;
913
}
1014

lib/torch_mobile.dart

+23-6
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,37 @@ import 'dart:async';
22
import 'dart:io';
33

44
import 'package:flutter/services.dart';
5+
import 'package:path/path.dart';
56
import 'package:path_provider/path_provider.dart';
67
import 'package:torch_mobile/model.dart';
78

89
class TorchMobile {
910
static const MethodChannel _channel = const MethodChannel('torch_mobile');
1011

11-
static Future<Model> getModel(String path) async{
12-
await _channel.invokeMethod("loadModel", await _getAbsolutePath(path));
13-
return Model();
12+
static Future<Model> getModel(String path) async {
13+
String absPath = await _getAbsolutePath(path);
14+
int index = await _channel
15+
.invokeMethod("loadModel", {"absPath": absPath, "assetPath": path});
16+
return Model(index);
1417
}
1518

16-
static Future<String> _getAbsolutePath(String path) async{
19+
static Future<String> _getAbsolutePath(String path) async {
1720
Directory dir = await getApplicationDocumentsDirectory();
18-
print(dir.listSync());
19-
return "";
21+
String dirPath = join(dir.path, path);
22+
ByteData data = await rootBundle.load(path);
23+
//Copy asset to documents directory
24+
List<int> bytes =
25+
data.buffer.asUint8List(data.offsetInBytes, data.lengthInBytes);
26+
27+
//create directory
28+
List split = path.split("/");
29+
for (int i = 0; i < split.length; i++) {
30+
if (i != split.length - 1) {
31+
await Directory(join(dir.path, split[i])).create();
32+
}
33+
}
34+
await File(dirPath).writeAsBytes(bytes);
35+
36+
return dirPath;
2037
}
2138
}

pubspec.lock

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ packages:
8989
source: hosted
9090
version: "1.1.8"
9191
path:
92-
dependency: transitive
92+
dependency: "direct main"
9393
description:
9494
name: path
9595
url: "https://pub.dartlang.org"

pubspec.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dependencies:
1111
flutter:
1212
sdk: flutter
1313
path_provider: ^1.6.5
14+
path: ^1.6.4
1415

1516
dev_dependencies:
1617
flutter_test:

0 commit comments

Comments
 (0)