Skip to content
This repository has been archived by the owner on Feb 14, 2024. It is now read-only.

Commit

Permalink
Flutter demo can start fresh when training
Browse files Browse the repository at this point in the history
  • Loading branch information
SichangHe committed Jul 27, 2023
1 parent 11120fe commit e03d941
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class MainActivity : FlutterActivity() {
val partitionId = call.argument<Int>("partitionId")!!
val host = call.argument<String>("host")!!
val backendUrl = call.argument<String>("backendUrl")!!
connect(partitionId, host, backendUrl, result)
val startFresh = call.argument<Boolean>("startFresh")!!
connect(partitionId, host, backendUrl, startFresh, result)
}

"train" -> train(result)
Expand All @@ -65,12 +66,17 @@ class MainActivity : FlutterActivity() {
}
}

suspend fun connect(partitionId: Int, host: String, backendUrl: String, result: Result) {
suspend fun connect(
partitionId: Int,
host: String,
backendUrl: String,
startFresh: Boolean,
result: Result
) {
train = Train(this, backendUrl, sampleSpec())
train.enableTelemetry(deviceId(this))
val modelFile = train.prepareModel(DATA_TYPE)
// TODO: freshStartCheckbox
val serverData = train.getServerInfo()
val serverData = train.getServerInfo(startFresh)
if (serverData.port == null) {
return result.error(
TAG, "Flower server port not available", "status ${serverData.status}"
Expand Down
17 changes: 14 additions & 3 deletions fed_kit_client/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class _MyAppState extends State<MyApp> {
final _channel = PlatformChannel();
var canConnect = true;
var canTrain = false;
var startFresh = false;

@override
void initState() {
Expand Down Expand Up @@ -100,7 +101,8 @@ class _MyAppState extends State<MyApp> {
'Connecting with Partition ID: $partitionId, Server IP: $host, Port: $backendPort');

try {
final serverPort = await _channel.connect(partitionId, host, backendUrl);
final serverPort = await _channel.connect(partitionId, host, backendUrl,
startFresh: startFresh);
canTrain = true;
return appendLog(
'Connected to Flower server on port $serverPort and loaded data set.');
Expand Down Expand Up @@ -163,6 +165,16 @@ class _MyAppState extends State<MyApp> {
),
keyboardType: TextInputType.number,
),
Row(
children: [
Checkbox(
value: startFresh,
onChanged: (checked) {
setState(() => startFresh = checked!);
}),
const Text('Start Fresh')
],
),
Row(mainAxisAlignment: MainAxisAlignment.center, children: [
ElevatedButton(
onPressed: canConnect ? connect : null,
Expand All @@ -173,7 +185,6 @@ class _MyAppState extends State<MyApp> {
child: const Text('Train'),
),
]),
const Text('Activity Log'),
Expanded(
child: ListView.builder(
controller: scrollController,
Expand All @@ -183,7 +194,7 @@ class _MyAppState extends State<MyApp> {
itemCount: logs.length,
itemBuilder: (context, index) => logs[logs.length - index - 1],
),
)
),
];

return MaterialApp(
Expand Down
6 changes: 4 additions & 2 deletions fed_kit_client/lib/platform_channel.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ class PlatformChannel {
return await methodChannel.invokeMethod<String>('getPlatformVersion');
}

Future<int?> connect(int partitionId, Uri host, Uri backendUrl) async {
Future<int?> connect(int partitionId, Uri host, Uri backendUrl,
{bool startFresh = false}) async {
return await methodChannel.invokeMethod<int>('connect', {
'partitionId': partitionId,
'host': host.host,
'backendUrl': backendUrl.toString()
'backendUrl': backendUrl.toString(),
'startFresh': startFresh,
});
}

Expand Down

0 comments on commit e03d941

Please sign in to comment.