From 623d18b621386f667a66b85ad46bf96d29e2cc5c Mon Sep 17 00:00:00 2001 From: Kristijan Date: Tue, 17 Dec 2024 12:34:17 +0100 Subject: [PATCH 1/8] Add launch.json file --- .vscode/launch.json | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..62cea89 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,25 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Test example APP (Debug Mode)", + "cwd": "example", + "request": "launch", + "type": "dart" + }, + { + "name": "Test example APP (profile mode)", + "cwd": "example", + "request": "launch", + "type": "dart", + "flutterMode": "profile" + }, + { + "name": "Test example APP (release mode)", + "cwd": "example", + "request": "launch", + "type": "dart", + "flutterMode": "release" + } + ] +} From 64aa7043089a853037a8f67136414eaa77347aba Mon Sep 17 00:00:00 2001 From: Kristijan Date: Tue, 17 Dec 2024 12:42:08 +0100 Subject: [PATCH 2/8] Implement `checkTunnelConfiguration` function - Checks tunnel setup and tunnel permissions --- .../wireguard_dart/WireguardDartPlugin.kt | 10 +- darwin/Classes/WireguardDartPlugin.swift | 226 +++++++++++++++--- example/lib/main.dart | 101 +++++--- lib/wireguard_dart.dart | 7 + lib/wireguard_dart_method_channel.dart | 12 + lib/wireguard_dart_platform_interface.dart | 7 + test/wireguard_dart_test.dart | 5 +- 7 files changed, 300 insertions(+), 68 deletions(-) diff --git a/android/src/main/kotlin/network/mysterium/wireguard_dart/WireguardDartPlugin.kt b/android/src/main/kotlin/network/mysterium/wireguard_dart/WireguardDartPlugin.kt index 67c158a..8752db0 100644 --- a/android/src/main/kotlin/network/mysterium/wireguard_dart/WireguardDartPlugin.kt +++ b/android/src/main/kotlin/network/mysterium/wireguard_dart/WireguardDartPlugin.kt @@ -132,7 +132,9 @@ class WireguardDartPlugin : FlutterPlugin, MethodCallHandler, ActivityAware, call.argument("tunnelName").toString(), result ) - + "checkTunnelConfiguration" -> { + checkTunnelConfiguration(result) + } "connect" -> connect(call.argument("cfg").toString(), result) "disconnect" -> disconnect(result) "status" -> status(result) @@ -141,6 +143,12 @@ class WireguardDartPlugin : FlutterPlugin, MethodCallHandler, ActivityAware, } } + private fun checkTunnelConfiguration(result: MethodChannel.Result) { + val intent = GoBackend.VpnService.prepare(this.activity) + havePermission = intent == null + return result.success(havePermission) + } + private fun generateKeyPair(result: Result) { val keyPair = KeyPair() result.success( diff --git a/darwin/Classes/WireguardDartPlugin.swift b/darwin/Classes/WireguardDartPlugin.swift index 906e205..7cf96d3 100644 --- a/darwin/Classes/WireguardDartPlugin.swift +++ b/darwin/Classes/WireguardDartPlugin.swift @@ -1,40 +1,43 @@ +import NetworkExtension +import WireGuardKit +import os + #if os(iOS) -import Flutter -import UIKit + import Flutter + import UIKit #elseif os(macOS) -import Cocoa -import FlutterMacOS + import Cocoa + import FlutterMacOS #else -#error("Unsupported platform") + #error("Unsupported platform") #endif -import NetworkExtension -import os -import WireGuardKit - public class WireguardDartPlugin: NSObject, FlutterPlugin { private var vpnManager: NETunnelProviderManager? - var vpnStatus: NEVPNStatus { vpnManager?.connection.status ?? NEVPNStatus.invalid } public static func register(with registrar: FlutterPluginRegistrar) { #if os(iOS) - let messenger = registrar.messenger() + let messenger = registrar.messenger() #else - let messenger = registrar.messenger + let messenger = registrar.messenger #endif - let channel = FlutterMethodChannel(name: "wireguard_dart", binaryMessenger: messenger) + let channel = FlutterMethodChannel( + name: "wireguard_dart", binaryMessenger: messenger) let instance = WireguardDartPlugin() registrar.addMethodCallDelegate(instance, channel: channel) - let statusChannel = FlutterEventChannel(name: "wireguard_dart/status", binaryMessenger: messenger) + let statusChannel = FlutterEventChannel( + name: "wireguard_dart/status", binaryMessenger: messenger) statusChannel.setStreamHandler(ConnectionStatusObserver()) } - public func handle(_ call: FlutterMethodCall, result: @escaping FlutterResult) { + public func handle( + _ call: FlutterMethodCall, result: @escaping FlutterResult + ) { switch call.method { case "nativeInit": result("") @@ -42,32 +45,47 @@ public class WireguardDartPlugin: NSObject, FlutterPlugin { let privateKey = PrivateKey() let privateKeyResponse: [String: Any] = [ "privateKey": privateKey.base64Key, - "publicKey": privateKey.publicKey.base64Key + "publicKey": privateKey.publicKey.base64Key, ] result(privateKeyResponse) case "setupTunnel": Logger.main.debug("handle setupTunnel") - guard let args = call.arguments as? [String: Any], args["bundleId"] != nil else { - result(nativeFlutterError(message: "required argument: 'bundleId'")) + guard let args = call.arguments as? [String: Any], + args["bundleId"] != nil + else { + result( + nativeFlutterError(message: "required argument: 'bundleId'") + ) return } - guard let bundleId = args["bundleId"] as? String, !bundleId.isEmpty else { - result(nativeFlutterError(message: "required argument: 'bundleId'")) + guard let bundleId = args["bundleId"] as? String, !bundleId.isEmpty + else { + result( + nativeFlutterError(message: "required argument: 'bundleId'") + ) return } - guard let tunnelName = args["tunnelName"] as? String, !tunnelName.isEmpty else { - result(nativeFlutterError(message: "required argument: 'tunnelName'")) + guard let tunnelName = args["tunnelName"] as? String, + !tunnelName.isEmpty + else { + result( + nativeFlutterError( + message: "required argument: 'tunnelName'")) return } - Logger.main.debug("Tunnel bundle ID: \(bundleId), name: \(tunnelName)") + Logger.main.debug( + "Tunnel bundle ID: \(bundleId), name: \(tunnelName)") Task { do { - vpnManager = try await setupProviderManager(bundleId: bundleId, tunnelName: tunnelName) + vpnManager = try await setupProviderManager( + bundleId: bundleId, tunnelName: tunnelName) Logger.main.debug("Tunnel setup OK") result("") } catch { Logger.main.error("Tunnel setup ERROR: \(error)") - result(nativeFlutterError(message: "could not setup VPN tunnel: \(error)")) + result( + nativeFlutterError( + message: "could not setup VPN tunnel: \(error)")) return } } @@ -75,7 +93,8 @@ public class WireguardDartPlugin: NSObject, FlutterPlugin { Logger.main.debug("handle connect") let cfg: String if let args = call.arguments as? [String: Any], - let argCfg = args["cfg"] as? String { + let argCfg = args["cfg"] as? String + { cfg = argCfg } else { Logger.main.error("Required argument 'cfg' not provided") @@ -83,27 +102,101 @@ public class WireguardDartPlugin: NSObject, FlutterPlugin { return } guard let mgr = vpnManager else { - Logger.main.error("Tunnel not initialized, missing 'vpnManager'") - result(nativeFlutterError(message: "tunnel not initialized, missing 'vpnManager'")) + Logger.main.error( + "Tunnel not initialized, missing 'vpnManager'") + result( + nativeFlutterError( + message: "tunnel not initialized, missing 'vpnManager'") + ) return } Logger.main.debug("Connection configuration: \(cfg)") Task { do { + mgr.isEnabled = true + + try await mgr.saveToPreferences() + try await mgr.loadFromPreferences() try mgr.connection.startVPNTunnel(options: [ "cfg": cfg as NSObject ]) Logger.main.debug("Start VPN tunnel OK") result("") + } catch let error as NEVPNError { + switch error.code { + case .configurationInvalid: + Logger.main.error( + "Start VPN tunnel ERROR: Configuration is invalid") + result( + nativeFlutterError( + message: + "could not start VPN tunnel: Configuration is invalid" + )) + case .configurationDisabled: + Logger.main.error( + "Start VPN tunnel ERROR: Configuration is disabled") + result( + nativeFlutterError( + message: + "could not start VPN tunnel: Configuration is disabled" + )) + case .connectionFailed: + Logger.main.error( + "Start VPN tunnel ERROR: Connection failed") + result( + nativeFlutterError( + message: + "could not start VPN tunnel: Connection failed" + )) + case .configurationStale: + Logger.main.error( + "Start VPN tunnel ERROR: Configuration is stale") + result( + nativeFlutterError( + message: + "could not start VPN tunnel: Configuration is stale" + )) + case .configurationReadWriteFailed: + Logger.main.error( + "Start VPN tunnel ERROR: Configuration read/write failed" + ) + result( + nativeFlutterError( + message: + "could not start VPN tunnel: Configuration read/write failed" + )) + case .configurationUnknown: + Logger.main.error( + "Start VPN tunnel ERROR: Configuration unknown") + result( + nativeFlutterError( + message: + "could not start VPN tunnel: Configuration unknown" + )) + @unknown default: + Logger.main.error( + "Start VPN tunnel ERROR: Unknown error") + result( + nativeFlutterError( + message: + "could not start VPN tunnel: Unknown error") + ) + } } catch { Logger.main.error("Start VPN tunnel ERROR: \(error)") - result(nativeFlutterError(message: "could not start VPN tunnel: \(error)")) + result( + nativeFlutterError( + message: "could not start VPN tunnel: \(error)")) } } case "disconnect": guard let mgr = vpnManager else { - Logger.main.error("Tunnel not initialized, missing 'vpnManager'") - result(nativeFlutterError(message: "tunnel not initialized, missing 'vpnManager'")) + Logger.main.error( + "Tunnel not initialized, missing 'vpnManager'") + result( + nativeFlutterError( + message: "tunnel not initialized, missing 'vpnManager'") + ) return } Task { @@ -112,30 +205,62 @@ public class WireguardDartPlugin: NSObject, FlutterPlugin { result("") } case "status": - guard vpnManager != nil else { - Logger.main.error("Tunnel not initialized, missing 'vpnManager'") - result(nativeFlutterError(message: "tunnel not initialized, missing 'vpnManager'")) + if vpnManager != nil { + Task { + result( + ConnectionStatus.fromNEVPNStatus(status: vpnStatus) + .string()) + } + } else { + result(ConnectionStatus.unknown.string()) + } + case "checkTunnelConfiguration": + guard let args = call.arguments as? [String: Any], + let bundleId = args["bundleId"] as? String, !bundleId.isEmpty + else { + result( + nativeFlutterError(message: "required argument: 'bundleId'") + ) return } - Task { - result(ConnectionStatus.fromNEVPNStatus(status: vpnStatus).string()) + guard let args = call.arguments as? [String: Any], + let tunnelName = args["tunnelName"] as? String, !tunnelName.isEmpty + else { + result( + nativeFlutterError(message: "required argument: 'tunnelName'") + ) + return } + checkTunnelConfiguration(bundleId: bundleId, tunnelName: tunnelName) { manager in + if let vpnManager = manager { + self.vpnManager = vpnManager + Logger.main.debug("Tunnel is set up and existing") + result(true) + } else { + Logger.main.debug("Tunnel is not set up") + result(false) + } +} + default: result(FlutterMethodNotImplemented) } } - func setupProviderManager(bundleId: String, tunnelName: String) async throws -> NETunnelProviderManager { + func setupProviderManager(bundleId: String, tunnelName: String) async throws + -> NETunnelProviderManager + { let mgrs = try await NETunnelProviderManager.loadAllFromPreferences() let existingMgr = mgrs.first(where: { - ($0.protocolConfiguration as? NETunnelProviderProtocol)?.providerBundleIdentifier == bundleId + ($0.protocolConfiguration as? NETunnelProviderProtocol)? + .providerBundleIdentifier == bundleId }) let mgr = existingMgr ?? NETunnelProviderManager() mgr.localizedDescription = tunnelName let proto = NETunnelProviderProtocol() proto.providerBundleIdentifier = bundleId - proto.serverAddress = "" // must be non-null + proto.serverAddress = "" // must be non-null mgr.protocolConfiguration = proto mgr.isEnabled = true @@ -144,4 +269,27 @@ public class WireguardDartPlugin: NSObject, FlutterPlugin { return mgr } + +func isVpnManagerConfigured(bundleId: String, tunnelName: String) async throws -> NETunnelProviderManager? { + // Load all managers from preferences + let mgrs = try await NETunnelProviderManager.loadAllFromPreferences() + if let existingMgr = mgrs.first(where: { + ($0.protocolConfiguration as? NETunnelProviderProtocol)?.providerBundleIdentifier == bundleId + }) { + return existingMgr + } + return nil +} + + func checkTunnelConfiguration(bundleId: String, tunnelName: String, result: @escaping (NETunnelProviderManager?) -> Void) { + Task { + do { + let mgr = try await isVpnManagerConfigured(bundleId: bundleId, tunnelName: tunnelName) + result(mgr) + } catch { + Logger.main.error("Error checking tunnel configuration: \(error)") + result(nil) + } + } +} } diff --git a/example/lib/main.dart b/example/lib/main.dart index 53253a1..d8bca0e 100644 --- a/example/lib/main.dart +++ b/example/lib/main.dart @@ -5,6 +5,7 @@ import 'dart:isolate'; import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; import 'package:wireguard_dart/connection_status.dart'; +import 'package:wireguard_dart/key_pair.dart'; import 'package:wireguard_dart/wireguard_dart.dart'; const tunBundleId = "network.mysterium.wireguardDartExample.tun"; @@ -42,6 +43,9 @@ class _MyAppState extends State { final _wireguardDartPlugin = WireguardDart(); ConnectionStatus _status = ConnectionStatus.unknown; late Stream _statusStream; + bool? _checkTunnelConfiguration; + bool? _isTunnelSetup; + KeyPair? _keyPair; @override void initState() { @@ -75,7 +79,10 @@ class _MyAppState extends State { void generateKey() async { try { var keyPair = await _wireguardDartPlugin.generateKeyPair(); - debugPrint('Generated key pair: $keyPair'); + setState(() { + _keyPair = keyPair; + }); + debugPrint('Generated key pair: $_keyPair'); } catch (e) { developer.log( 'Generated key', @@ -89,11 +96,35 @@ class _MyAppState extends State { Isolate.spawn(nativeInitBackground, [rootIsolateToken]); } + Future checkTunnelConfiguration() async { + try { + final status = await _wireguardDartPlugin.checkTunnelConfiguration( + bundleId: tunBundleId, + tunnelName: "WiregardDart", + ); + setState(() { + _checkTunnelConfiguration = status; + }); + debugPrint("Tunnel configured status: $_checkTunnelConfiguration"); + } catch (e) { + developer.log( + 'Is tunnel configured', + error: e, + ); + } + } + void setupTunnel() async { try { await _wireguardDartPlugin.setupTunnel(bundleId: tunBundleId, tunnelName: "WiregardDart", win32ServiceName: winSvcName); + setState(() { + _isTunnelSetup = true; + }); debugPrint("Setup tunnel success"); } catch (e) { + setState(() { + _isTunnelSetup = false; + }); developer.log( 'Setup tunnel', error: e, @@ -155,10 +186,10 @@ class _MyAppState extends State { TextButton( onPressed: generateKey, style: ButtonStyle( - minimumSize: MaterialStateProperty.all(const Size(100, 50)), - padding: MaterialStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), - backgroundColor: MaterialStateProperty.all(Colors.blueAccent), - overlayColor: MaterialStateProperty.all(Colors.white.withOpacity(0.1))), + minimumSize: WidgetStateProperty.all(const Size(100, 50)), + padding: WidgetStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), + backgroundColor: WidgetStateProperty.all(Colors.blueAccent), + overlayColor: WidgetStateProperty.all(Colors.white.withOpacity(0.1))), child: const Text( 'Generate Key', style: TextStyle(color: Colors.white), @@ -168,23 +199,36 @@ class _MyAppState extends State { TextButton( onPressed: nativeInit, style: ButtonStyle( - minimumSize: MaterialStateProperty.all(const Size(100, 50)), - padding: MaterialStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), - backgroundColor: MaterialStateProperty.all(Colors.blueAccent), - overlayColor: MaterialStateProperty.all(Colors.white.withOpacity(0.1))), + minimumSize: WidgetStateProperty.all(const Size(100, 50)), + padding: WidgetStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), + backgroundColor: WidgetStateProperty.all(Colors.blueAccent), + overlayColor: WidgetStateProperty.all(Colors.white.withOpacity(0.1))), child: const Text( 'Native initialization', style: TextStyle(color: Colors.white), ), ), const SizedBox(height: 20), + TextButton( + onPressed: checkTunnelConfiguration, + style: ButtonStyle( + minimumSize: WidgetStateProperty.all(const Size(100, 50)), + padding: WidgetStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), + backgroundColor: WidgetStateProperty.all(Colors.blueAccent), + overlayColor: WidgetStateProperty.all(Colors.white.withOpacity(0.1))), + child: const Text( + 'Is Tunnel Configured', + style: TextStyle(color: Colors.white), + ), + ), + const SizedBox(height: 20), TextButton( onPressed: setupTunnel, style: ButtonStyle( - minimumSize: MaterialStateProperty.all(const Size(100, 50)), - padding: MaterialStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), - backgroundColor: MaterialStateProperty.all(Colors.blueAccent), - overlayColor: MaterialStateProperty.all(Colors.white.withOpacity(0.1))), + minimumSize: WidgetStateProperty.all(const Size(100, 50)), + padding: WidgetStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), + backgroundColor: WidgetStateProperty.all(Colors.blueAccent), + overlayColor: WidgetStateProperty.all(Colors.white.withOpacity(0.1))), child: const Text( 'Setup Tunnel', style: TextStyle(color: Colors.white), @@ -194,10 +238,10 @@ class _MyAppState extends State { TextButton( onPressed: connect, style: ButtonStyle( - minimumSize: MaterialStateProperty.all(const Size(100, 50)), - padding: MaterialStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), - backgroundColor: MaterialStateProperty.all(Colors.blueAccent), - overlayColor: MaterialStateProperty.all(Colors.white.withOpacity(0.1))), + minimumSize: WidgetStateProperty.all(const Size(100, 50)), + padding: WidgetStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), + backgroundColor: WidgetStateProperty.all(Colors.blueAccent), + overlayColor: WidgetStateProperty.all(Colors.white.withOpacity(0.1))), child: const Text( 'Connect', style: TextStyle(color: Colors.white), @@ -207,10 +251,10 @@ class _MyAppState extends State { TextButton( onPressed: disconnect, style: ButtonStyle( - minimumSize: MaterialStateProperty.all(const Size(100, 50)), - padding: MaterialStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), - backgroundColor: MaterialStateProperty.all(Colors.blueAccent), - overlayColor: MaterialStateProperty.all(Colors.white.withOpacity(0.1))), + minimumSize: WidgetStateProperty.all(const Size(100, 50)), + padding: WidgetStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), + backgroundColor: WidgetStateProperty.all(Colors.blueAccent), + overlayColor: WidgetStateProperty.all(Colors.white.withOpacity(0.1))), child: const Text( 'Disconnect', style: TextStyle(color: Colors.white), @@ -220,27 +264,30 @@ class _MyAppState extends State { TextButton( onPressed: status, style: ButtonStyle( - minimumSize: MaterialStateProperty.all(const Size(100, 50)), - padding: MaterialStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), - backgroundColor: MaterialStateProperty.all(Colors.blueAccent), - overlayColor: MaterialStateProperty.all(Colors.white.withOpacity(0.1))), + minimumSize: WidgetStateProperty.all(const Size(100, 50)), + padding: WidgetStateProperty.all(const EdgeInsets.fromLTRB(20, 15, 20, 15)), + backgroundColor: WidgetStateProperty.all(Colors.blueAccent), + overlayColor: WidgetStateProperty.all(Colors.white.withOpacity(0.1))), child: const Text( 'Query status', style: TextStyle(color: Colors.white), ), ), const SizedBox(height: 20), - Text(_status.name), + Text("Query tunnel status: ${_status.name}"), StreamBuilder( initialData: ConnectionStatus.unknown, stream: _statusStream, builder: (BuildContext context, AsyncSnapshot snapshot) { // Check if the snapshot has data and is a map containing the 'status' key if (snapshot.hasData) { - return Text(snapshot.data!.name); + return Text("Tunnel stream status: ${snapshot.data!.name}"); } return const CircularProgressIndicator(); }), + Text('Tunnel configured: $_checkTunnelConfiguration'), + Text('Tunnel setup: $_isTunnelSetup'), + Text('Key pair:\n Public key:${_keyPair?.publicKey}\n Private key:${_keyPair?.privateKey}'), ], ), ), diff --git a/lib/wireguard_dart.dart b/lib/wireguard_dart.dart index 62b7454..89ea4cb 100644 --- a/lib/wireguard_dart.dart +++ b/lib/wireguard_dart.dart @@ -31,4 +31,11 @@ class WireguardDart { Stream statusStream() { return WireguardDartPlatform.instance.statusStream(); } + + Future checkTunnelConfiguration({required String bundleId, required String tunnelName}) { + return WireguardDartPlatform.instance.checkTunnelConfiguration( + bundleId: bundleId, + tunnelName: tunnelName, + ); + } } diff --git a/lib/wireguard_dart_method_channel.dart b/lib/wireguard_dart_method_channel.dart index b22cf91..b71d4cf 100644 --- a/lib/wireguard_dart_method_channel.dart +++ b/lib/wireguard_dart_method_channel.dart @@ -54,4 +54,16 @@ class MethodChannelWireguardDart extends WireguardDartPlatform { Stream statusStream() { return statusChannel.receiveBroadcastStream().distinct().map((val) => ConnectionStatus.fromString(val)); } + + @override + Future checkTunnelConfiguration({ + required String bundleId, + required String tunnelName, + }) async { + var result = await methodChannel.invokeMethod('checkTunnelConfiguration', { + 'bundleId': bundleId, + 'tunnelName': tunnelName, + }); + return result as bool; + } } diff --git a/lib/wireguard_dart_platform_interface.dart b/lib/wireguard_dart_platform_interface.dart index 01055d8..d3e9fad 100644 --- a/lib/wireguard_dart_platform_interface.dart +++ b/lib/wireguard_dart_platform_interface.dart @@ -52,4 +52,11 @@ abstract class WireguardDartPlatform extends PlatformInterface { Stream statusStream() { throw UnimplementedError('statusStream() has not been implemented'); } + + Future checkTunnelConfiguration({ + required String bundleId, + required String tunnelName, + }) { + throw UnimplementedError('checkTunnelConfiguration() has not been implemented'); + } } diff --git a/test/wireguard_dart_test.dart b/test/wireguard_dart_test.dart index 8e3436e..904ce1a 100644 --- a/test/wireguard_dart_test.dart +++ b/test/wireguard_dart_test.dart @@ -2,9 +2,9 @@ import 'package:flutter_test/flutter_test.dart'; import 'package:plugin_platform_interface/plugin_platform_interface.dart'; import 'package:wireguard_dart/connection_status.dart'; import 'package:wireguard_dart/key_pair.dart'; +import 'package:wireguard_dart/wireguard_dart.dart'; import 'package:wireguard_dart/wireguard_dart_method_channel.dart'; import 'package:wireguard_dart/wireguard_dart_platform_interface.dart'; -import 'package:wireguard_dart/wireguard_dart.dart'; class MockWireguardDartPlatform with MockPlatformInterfaceMixin implements WireguardDartPlatform { @override @@ -31,6 +31,9 @@ class MockWireguardDartPlatform with MockPlatformInterfaceMixin implements Wireg Stream statusStream() { return Stream.value(ConnectionStatus.disconnected); } + + @override + Future checkTunnelConfiguration({required String bundleId, required String tunnelName}) => Future.value(true); } void main() { From 0d87039566bb18f323ddbe0b637fb797a1505c62 Mon Sep 17 00:00:00 2001 From: Kristijan Mitrikjeski Date: Tue, 17 Dec 2024 13:13:44 +0100 Subject: [PATCH 3/8] Add checkTunnelConfiguration for windows --- windows/wireguard_dart_plugin.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/windows/wireguard_dart_plugin.cpp b/windows/wireguard_dart_plugin.cpp index 705a96c..6f74e0e 100644 --- a/windows/wireguard_dart_plugin.cpp +++ b/windows/wireguard_dart_plugin.cpp @@ -71,6 +71,10 @@ void WireguardDartPlugin::HandleMethodCall(const flutter::MethodCallSuccess(flutter::EncodableValue(true)); + } + if (call.method_name() == "nativeInit") { // Disable packet forwarding that conflicts with WireGuard ServiceControl remoteAccessService = ServiceControl(L"RemoteAccess"); From 9d88ab7bf4c75b4a80e48b10a0358818d1da29de Mon Sep 17 00:00:00 2001 From: Kristijan Mitrikjeski Date: Tue, 17 Dec 2024 13:53:18 +0100 Subject: [PATCH 4/8] Improve connection status observer impl --- windows/connection_status_observer.cpp | 39 +++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/windows/connection_status_observer.cpp b/windows/connection_status_observer.cpp index a2ce11f..c07a6ef 100644 --- a/windows/connection_status_observer.cpp +++ b/windows/connection_status_observer.cpp @@ -16,50 +16,63 @@ void ConnectionStatusObserver::StartObserving(std::wstring service_name) { if (m_running.load() == true) { return; } - if (service_name.size() > 0) { + + if (!service_name.empty()) { m_service_name = service_name; } + + if (m_service_name.empty()) { + std::cerr << "Service name is empty" << std::endl; + return; + } + SC_HANDLE service_manager = OpenSCManager(NULL, NULL, SC_MANAGER_ALL_ACCESS); if (service_manager == NULL) { + std::cerr << "Failed to open service manager: " << GetLastError() << std::endl; return; } SC_HANDLE service = OpenService(service_manager, &m_service_name[0], SERVICE_QUERY_STATUS | SERVICE_INTERROGATE); - if (service != NULL) { - m_running.store(true); - watch_thread = std::thread(&ConnectionStatusObserver::StartObservingThreadProc, this, service_manager, service); - } else { + if (service == NULL) { + std::cerr << "Failed to open service: " << GetLastError() << std::endl; CloseServiceHandle(service_manager); - m_running.store(false); return; } -} -void ConnectionStatusObserver::StopObserving() { m_watch_thread_stop.store(true); } + m_running.store(true); + m_watch_thread_stop.store(false); + watch_thread = std::thread(&ConnectionStatusObserver::StartObservingThreadProc, this, service_manager, service); +} -void ConnectionStatusObserver::Shutdown() { +void ConnectionStatusObserver::StopObserving() { m_watch_thread_stop.store(true); if (watch_thread.joinable()) { watch_thread.join(); } } +void ConnectionStatusObserver::Shutdown() { + StopObserving(); + m_running.store(false); +} + void ConnectionStatusObserver::StartObservingThreadProc(SC_HANDLE service_manager, SC_HANDLE service) { SERVICE_NOTIFY s_notify = {0}; s_notify.dwVersion = SERVICE_NOTIFY_STATUS_CHANGE; s_notify.pfnNotifyCallback = &ServiceNotifyCallback; s_notify.pContext = static_cast(this); - while (m_watch_thread_stop.load() == false) { + while (!m_watch_thread_stop.load()) { if (NotifyServiceStatusChange(service, SERVICE_NOTIFY_RUNNING | SERVICE_NOTIFY_START_PENDING | SERVICE_NOTIFY_STOPPED | SERVICE_NOTIFY_STOP_PENDING, &s_notify) == ERROR_SUCCESS) { ::SleepEx(INFINITE, true); } else { - CloseServiceHandle(service); - CloseServiceHandle(service_manager); + std::cerr << "Failed to notify service status change: " << GetLastError() << std::endl; break; } } + CloseServiceHandle(service); + CloseServiceHandle(service_manager); m_running.store(false); } @@ -68,6 +81,7 @@ void CALLBACK ConnectionStatusObserver::ServiceNotifyCallback(void* ptr) { ConnectionStatusObserver* instance = static_cast(serviceNotify->pContext); if (!instance || serviceNotify->dwNotificationStatus != ERROR_SUCCESS) { + std::cerr << "Service notification failed: " << serviceNotify->dwNotificationStatus << std::endl; return; } @@ -82,7 +96,6 @@ void CALLBACK ConnectionStatusObserver::ServiceNotifyCallback(void* ptr) { std::unique_ptr> ConnectionStatusObserver::OnListenInternal( const flutter::EncodableValue* arguments, std::unique_ptr>&& events) { sink_ = std::move(events); - // sink_->Success(flutter::EncodableValue(ConnectionStatusToString(ConnectionStatus::disconnected))); return nullptr; } From 4db828b02f6a1cdf0d27918c5c581e7d228062aa Mon Sep 17 00:00:00 2001 From: Kristijan Mitrikjeski Date: Tue, 17 Dec 2024 13:53:29 +0100 Subject: [PATCH 5/8] Check for tunnel service init --- windows/wireguard_dart_plugin.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/windows/wireguard_dart_plugin.cpp b/windows/wireguard_dart_plugin.cpp index 6f74e0e..abe38b8 100644 --- a/windows/wireguard_dart_plugin.cpp +++ b/windows/wireguard_dart_plugin.cpp @@ -72,7 +72,12 @@ void WireguardDartPlugin::HandleMethodCall(const flutter::MethodCallSuccess(flutter::EncodableValue(true)); + auto tunnel_service = this->tunnel_service_.get(); + if (tunnel_service == nullptr) { + result->Success(flutter::EncodableValue(false)); + } else { + result->Success(flutter::EncodableValue(true)); + } } if (call.method_name() == "nativeInit") { From b79d5ba7a18e664d29156bf5ca3758b3973b2594 Mon Sep 17 00:00:00 2001 From: Kristijan Mitrikjeski Date: Tue, 17 Dec 2024 14:36:31 +0100 Subject: [PATCH 6/8] Refactor method --- windows/wireguard_dart_plugin.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/windows/wireguard_dart_plugin.cpp b/windows/wireguard_dart_plugin.cpp index abe38b8..90d74b7 100644 --- a/windows/wireguard_dart_plugin.cpp +++ b/windows/wireguard_dart_plugin.cpp @@ -73,11 +73,8 @@ void WireguardDartPlugin::HandleMethodCall(const flutter::MethodCalltunnel_service_.get(); - if (tunnel_service == nullptr) { - result->Success(flutter::EncodableValue(false)); - } else { - result->Success(flutter::EncodableValue(true)); - } + result->Success(flutter::EncodableValue(tunnel_service != nullptr)); + return; } if (call.method_name() == "nativeInit") { From 0dc4967035c712f5902801197f8ab737fb47023a Mon Sep 17 00:00:00 2001 From: Kristijan Date: Tue, 17 Dec 2024 15:48:47 +0100 Subject: [PATCH 7/8] Switch var to final --- lib/wireguard_dart_method_channel.dart | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/wireguard_dart_method_channel.dart b/lib/wireguard_dart_method_channel.dart index b71d4cf..80d4d06 100644 --- a/lib/wireguard_dart_method_channel.dart +++ b/lib/wireguard_dart_method_channel.dart @@ -12,7 +12,7 @@ class MethodChannelWireguardDart extends WireguardDartPlatform { @override Future generateKeyPair() async { - var result = await methodChannel.invokeMapMethod('generateKeyPair') ?? {}; + final result = await methodChannel.invokeMapMethod('generateKeyPair') ?? {}; if (!result.containsKey('publicKey') || !result.containsKey('privateKey')) { throw StateError('Could not generate keypair'); } @@ -26,7 +26,7 @@ class MethodChannelWireguardDart extends WireguardDartPlatform { @override Future setupTunnel({required String bundleId, required String tunnelName, String? win32ServiceName}) async { - var args = { + final args = { 'bundleId': bundleId, 'tunnelName': tunnelName, if (win32ServiceName != null) 'win32ServiceName': win32ServiceName, @@ -46,7 +46,7 @@ class MethodChannelWireguardDart extends WireguardDartPlatform { @override Future status() async { - var result = await methodChannel.invokeMethod('status'); + final result = await methodChannel.invokeMethod('status'); return ConnectionStatus.fromString(result ?? ""); } @@ -60,7 +60,7 @@ class MethodChannelWireguardDart extends WireguardDartPlatform { required String bundleId, required String tunnelName, }) async { - var result = await methodChannel.invokeMethod('checkTunnelConfiguration', { + final result = await methodChannel.invokeMethod('checkTunnelConfiguration', { 'bundleId': bundleId, 'tunnelName': tunnelName, }); From f7dfc8a97f9dd5f6bfd9a436e3b3be652abf6cbe Mon Sep 17 00:00:00 2001 From: Kristijan Date: Wed, 18 Dec 2024 12:48:09 +0100 Subject: [PATCH 8/8] Re-enable the tunnel if it's disabled before connecting --- darwin/Classes/WireguardDartPlugin.swift | 82 ++++++++++++++---------- 1 file changed, 48 insertions(+), 34 deletions(-) diff --git a/darwin/Classes/WireguardDartPlugin.swift b/darwin/Classes/WireguardDartPlugin.swift index 7cf96d3..9d84280 100644 --- a/darwin/Classes/WireguardDartPlugin.swift +++ b/darwin/Classes/WireguardDartPlugin.swift @@ -113,10 +113,13 @@ public class WireguardDartPlugin: NSObject, FlutterPlugin { Logger.main.debug("Connection configuration: \(cfg)") Task { do { - mgr.isEnabled = true + if !mgr.isEnabled { + mgr.isEnabled = true + try await mgr.saveToPreferences() + try await mgr.loadFromPreferences() + + } - try await mgr.saveToPreferences() - try await mgr.loadFromPreferences() try mgr.connection.startVPNTunnel(options: [ "cfg": cfg as NSObject ]) @@ -216,7 +219,7 @@ public class WireguardDartPlugin: NSObject, FlutterPlugin { } case "checkTunnelConfiguration": guard let args = call.arguments as? [String: Any], - let bundleId = args["bundleId"] as? String, !bundleId.isEmpty + let bundleId = args["bundleId"] as? String, !bundleId.isEmpty else { result( nativeFlutterError(message: "required argument: 'bundleId'") @@ -224,23 +227,26 @@ public class WireguardDartPlugin: NSObject, FlutterPlugin { return } guard let args = call.arguments as? [String: Any], - let tunnelName = args["tunnelName"] as? String, !tunnelName.isEmpty + let tunnelName = args["tunnelName"] as? String, + !tunnelName.isEmpty else { result( - nativeFlutterError(message: "required argument: 'tunnelName'") + nativeFlutterError( + message: "required argument: 'tunnelName'") ) return } - checkTunnelConfiguration(bundleId: bundleId, tunnelName: tunnelName) { manager in - if let vpnManager = manager { - self.vpnManager = vpnManager - Logger.main.debug("Tunnel is set up and existing") - result(true) - } else { - Logger.main.debug("Tunnel is not set up") - result(false) - } -} + checkTunnelConfiguration(bundleId: bundleId, tunnelName: tunnelName) + { manager in + if let vpnManager = manager { + self.vpnManager = vpnManager + Logger.main.debug("Tunnel is set up and existing") + result(true) + } else { + Logger.main.debug("Tunnel is not set up") + result(false) + } + } default: result(FlutterMethodNotImplemented) @@ -270,26 +276,34 @@ public class WireguardDartPlugin: NSObject, FlutterPlugin { return mgr } -func isVpnManagerConfigured(bundleId: String, tunnelName: String) async throws -> NETunnelProviderManager? { - // Load all managers from preferences - let mgrs = try await NETunnelProviderManager.loadAllFromPreferences() - if let existingMgr = mgrs.first(where: { - ($0.protocolConfiguration as? NETunnelProviderProtocol)?.providerBundleIdentifier == bundleId - }) { - return existingMgr + func isVpnManagerConfigured(bundleId: String, tunnelName: String) + async throws -> NETunnelProviderManager? + { + // Load all managers from preferences + let mgrs = try await NETunnelProviderManager.loadAllFromPreferences() + if let existingMgr = mgrs.first(where: { + ($0.protocolConfiguration as? NETunnelProviderProtocol)? + .providerBundleIdentifier == bundleId + }) { + return existingMgr + } + return nil } - return nil -} - func checkTunnelConfiguration(bundleId: String, tunnelName: String, result: @escaping (NETunnelProviderManager?) -> Void) { - Task { - do { - let mgr = try await isVpnManagerConfigured(bundleId: bundleId, tunnelName: tunnelName) - result(mgr) - } catch { - Logger.main.error("Error checking tunnel configuration: \(error)") - result(nil) + func checkTunnelConfiguration( + bundleId: String, tunnelName: String, + result: @escaping (NETunnelProviderManager?) -> Void + ) { + Task { + do { + let mgr = try await isVpnManagerConfigured( + bundleId: bundleId, tunnelName: tunnelName) + result(mgr) + } catch { + Logger.main.error( + "Error checking tunnel configuration: \(error)") + result(nil) + } } } } -}