Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add complete server/client TLS support #158

Merged
merged 18 commits into from
Aug 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 193 additions & 59 deletions README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import android.annotation.SuppressLint;
import android.content.Context;

import androidx.annotation.NonNull;
import androidx.annotation.RawRes;

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
Expand All @@ -13,14 +16,16 @@
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509ExtendedKeyManager;
import javax.net.ssl.X509TrustManager;

import androidx.annotation.NonNull;
import androidx.annotation.RawRes;

final class SSLCertificateHelper {
/**
Expand All @@ -34,6 +39,23 @@ static SSLSocketFactory createBlindSocketFactory() throws GeneralSecurityExcepti
return ctx.getSocketFactory();
}

static SSLServerSocketFactory createServerSocketFactory(Context context, @NonNull final String keyStoreResourceUri) throws GeneralSecurityException, IOException {
char[] password = "".toCharArray();

InputStream keyStoreInput = getRawResourceStream(context, keyStoreResourceUri);
KeyStore keyStore = KeyStore.getInstance("PKCS12");
keyStore.load(keyStoreInput, password);
keyStoreInput.close();

KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance("X509");
keyManagerFactory.init(keyStore, password);

SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(keyManagerFactory.getKeyManagers(), new TrustManager[]{new BlindTrustManager()}, null);

return sslContext.getServerSocketFactory();
}

/**
* Creates an SSLSocketFactory instance for use with the CA provided in the resource file.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.asterinet.react.tcpsocket;

import android.util.Base64;
import android.util.Log;

import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.ReactContext;
Expand All @@ -24,6 +25,14 @@ public TcpEventListener(final ReactContext reactContext) {
}

public void onConnection(int serverId, int clientId, Socket socket) {
onSocketConnection("connection", serverId, clientId, socket);
}

public void onSecureConnection(int serverId, int clientId, Socket socket) {
onSocketConnection("secureConnection", serverId, clientId, socket);
}

private void onSocketConnection(String connectionType, int serverId, int clientId, Socket socket) {
WritableMap eventParams = Arguments.createMap();
eventParams.putInt("id", serverId);

Expand All @@ -42,7 +51,7 @@ public void onConnection(int serverId, int clientId, Socket socket) {
infoParams.putMap("connection", connectionParams);
eventParams.putMap("info", infoParams);

sendEvent("connection", eventParams);
sendEvent(connectionType, eventParams);
}

public void onConnect(int id, TcpSocketClient client) {
Expand Down Expand Up @@ -83,7 +92,12 @@ public void onData(int id, byte[] data) {
sendEvent("data", eventParams);
}

public void onWritten(int id, int msgId, @Nullable String error) {
public void onWritten(int id, int msgId, @Nullable Exception e) {
String error = null;
if (e != null) {
Log.e(TcpSocketModule.TAG, "Exception on socket " + id, e);
error = e.getMessage();
}
WritableMap eventParams = Arguments.createMap();
eventParams.putInt("id", id);
eventParams.putInt("msgId", msgId);
Expand All @@ -92,18 +106,20 @@ public void onWritten(int id, int msgId, @Nullable String error) {
sendEvent("written", eventParams);
}

public void onClose(int id, String error) {
if (error != null) {
onError(id, error);
public void onClose(int id, Exception e) {
if (e != null) {
onError(id, e);
}
WritableMap eventParams = Arguments.createMap();
eventParams.putInt("id", id);
eventParams.putBoolean("hadError", error != null);
eventParams.putBoolean("hadError", e != null);

sendEvent("close", eventParams);
}

public void onError(int id, String error) {
public void onError(int id, Exception e) {
Log.e(TcpSocketModule.TAG, "Exception on socket " + id, e);
String error = e.getMessage();
WritableMap eventParams = Arguments.createMap();
eventParams.putInt("id", id);
eventParams.putString("error", error);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import javax.net.SocketFactory;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

Expand All @@ -25,6 +24,7 @@ class TcpSocketClient extends TcpSocket {
private final TcpEventListener receiverListener;
private TcpReceiverTask receiverTask;
private Socket socket;
private boolean closed = true;

TcpSocketClient(TcpEventListener receiverListener, Integer id, Socket socket) {
super(id);
Expand All @@ -38,20 +38,12 @@ public Socket getSocket() {
return socket;
}

public void connect(Context context, String address, final Integer port, ReadableMap options, Network network) throws IOException, GeneralSecurityException {
public void connect(Context context, String address, final Integer port, ReadableMap options, Network network, ReadableMap tlsOptions) throws IOException, GeneralSecurityException {
if (socket != null) throw new IOException("Already connected");
final boolean isTls = options.hasKey("tls") && options.getBoolean("tls");
if (isTls) {
SocketFactory sf;
if (options.hasKey("tlsCheckValidity") && !options.getBoolean("tlsCheckValidity")) {
sf = SSLCertificateHelper.createBlindSocketFactory();
} else {
final String customTlsCert = options.hasKey("tlsCert") ? options.getString("tlsCert") : null;
sf = customTlsCert != null ? SSLCertificateHelper.createCustomTrustedSocketFactory(context, customTlsCert) : SSLSocketFactory.getDefault();
}
final SSLSocket sslSocket = (SSLSocket) sf.createSocket();
sslSocket.setUseClientMode(true);
socket = sslSocket;
if (tlsOptions != null) {
SSLSocketFactory ssf = getSSLSocketFactory(context, tlsOptions);
socket = ssf.createSocket();
((SSLSocket) socket).setUseClientMode(true);
} else {
socket = new Socket();
}
Expand All @@ -73,10 +65,30 @@ public void connect(Context context, String address, final Integer port, Readabl
// bind
socket.bind(new InetSocketAddress(localInetAddress, localPort));
socket.connect(new InetSocketAddress(remoteInetAddress, port));
if (isTls) ((SSLSocket) socket).startHandshake();
if (socket instanceof SSLSocket) ((SSLSocket) socket).startHandshake();
startListening();
}

public void startTLS(Context context, ReadableMap tlsOptions) throws IOException, GeneralSecurityException {
if (socket instanceof SSLSocket) return;
SSLSocketFactory ssf = getSSLSocketFactory(context, tlsOptions);
SSLSocket sslSocket = (SSLSocket) ssf.createSocket(socket, socket.getInetAddress().getHostAddress(), socket.getPort(), true);
sslSocket.setUseClientMode(true);
sslSocket.startHandshake();
socket = sslSocket;
}

private SSLSocketFactory getSSLSocketFactory(Context context, ReadableMap tlsOptions) throws GeneralSecurityException, IOException {
SSLSocketFactory ssf;
if (tlsOptions.hasKey("rejectUnauthorized") && !tlsOptions.getBoolean("rejectUnauthorized")) {
ssf = SSLCertificateHelper.createBlindSocketFactory();
} else {
final String customTlsCert = tlsOptions.hasKey("ca") ? tlsOptions.getString("ca") : null;
ssf = customTlsCert != null ? SSLCertificateHelper.createCustomTrustedSocketFactory(context, customTlsCert) : (SSLSocketFactory) SSLSocketFactory.getDefault();
}
return ssf;
}

public void startListening() {
receiverTask = new TcpReceiverTask(this, receiverListener);
listenExecutor.execute(receiverTask);
Expand All @@ -95,8 +107,8 @@ public void run() {
socket.getOutputStream().write(data);
receiverListener.onWritten(getId(), msgId, null);
} catch (IOException e) {
receiverListener.onWritten(getId(), msgId, e.toString());
receiverListener.onError(getId(), e.toString());
receiverListener.onWritten(getId(), msgId, e);
receiverListener.onError(getId(), e);
}
}
});
Expand All @@ -109,12 +121,13 @@ public void destroy() {
try {
// close the socket
if (socket != null && !socket.isClosed()) {
closed = true;
socket.close();
receiverListener.onClose(getId(), null);
socket = null;
}
} catch (IOException e) {
receiverListener.onClose(getId(), e.getMessage());
receiverListener.onClose(getId(), e);
}
}

Expand Down Expand Up @@ -183,8 +196,8 @@ public void run() {
}
}
} catch (IOException | InterruptedException ioe) {
if (receiverListener != null && !socket.isClosed()) {
receiverListener.onError(socketId, ioe.getMessage());
if (receiverListener != null && !socket.isClosed() && !clientSocket.closed) {
receiverListener.onError(socketId, ioe);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import android.annotation.SuppressLint;
import android.content.Context;
import android.net.ConnectivityManager;
import android.net.Network;
import android.net.NetworkCapabilities;
import android.net.NetworkRequest;
import android.util.Base64;
import android.net.Network;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;

import com.facebook.react.bridge.ReactApplicationContext;
import com.facebook.react.bridge.ReactContextBaseJavaModule;
Expand All @@ -22,14 +25,12 @@
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;

public class TcpSocketModule extends ReactContextBaseJavaModule {
private static final String TAG = "TcpSockets";
public static final String TAG = "TcpSockets";
private static final int N_THREADS = 2;
private final ReactApplicationContext mReactContext;
private final ConcurrentHashMap<Integer, TcpSocket> socketMap = new ConcurrentHashMap<>();
private final ConcurrentHashMap<Integer, ReadableMap> pendingTLS = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, Network> mNetworkMap = new ConcurrentHashMap<>();
private final CurrentNetwork currentNetwork = new CurrentNetwork();
private final ExecutorService executorService = Executors.newFixedThreadPool(N_THREADS);
Expand Down Expand Up @@ -68,7 +69,7 @@ public void connect(@NonNull final Integer cId, @NonNull final String host, @Non
@Override
public void run() {
if (socketMap.get(cId) != null) {
tcpEvtListener.onError(cId, TAG + "createSocket called twice with the same id.");
tcpEvtListener.onError(cId, new Exception("connect() called twice with the same id."));
return;
}
try {
Expand All @@ -78,15 +79,33 @@ public void run() {
selectNetwork(iface, localAddress);
TcpSocketClient client = new TcpSocketClient(tcpEvtListener, cId, null);
socketMap.put(cId, client);
client.connect(mReactContext, host, port, options, currentNetwork.getNetwork());
ReadableMap tlsOptions = pendingTLS.get(cId);
client.connect(mReactContext, host, port, options, currentNetwork.getNetwork(), tlsOptions);
tcpEvtListener.onConnect(cId, client);
} catch (Exception e) {
tcpEvtListener.onError(cId, e.getMessage());
tcpEvtListener.onError(cId, e);
}
}
});
}

@SuppressLint("StaticFieldLeak")
@SuppressWarnings("unused")
@ReactMethod
public void startTLS(final int cId, @NonNull final ReadableMap tlsOptions) {
TcpSocketClient socketClient = (TcpSocketClient) socketMap.get(cId);
// Not yet connected
if (socketClient == null) {
pendingTLS.put(cId, tlsOptions);
} else {
try {
socketClient.startTLS(mReactContext, tlsOptions);
} catch (Exception e) {
tcpEvtListener.onError(cId, e);
}
}
}

@SuppressLint("StaticFieldLeak")
@SuppressWarnings("unused")
@ReactMethod
Expand Down Expand Up @@ -137,11 +156,11 @@ public void listen(final Integer cId, final ReadableMap options) {
@Override
public void run() {
try {
TcpSocketServer server = new TcpSocketServer(socketMap, tcpEvtListener, cId, options);
TcpSocketServer server = new TcpSocketServer(mReactContext, socketMap, tcpEvtListener, cId, options);
socketMap.put(cId, server);
tcpEvtListener.onListen(cId, server);
} catch (Exception uhe) {
tcpEvtListener.onError(cId, uhe.getMessage());
tcpEvtListener.onError(cId, uhe);
}
}
});
Expand All @@ -154,7 +173,7 @@ public void setNoDelay(@NonNull final Integer cId, final boolean noDelay) {
try {
client.setNoDelay(noDelay);
} catch (IOException e) {
tcpEvtListener.onError(cId, e.getMessage());
tcpEvtListener.onError(cId, e);
}
}

Expand All @@ -165,7 +184,7 @@ public void setKeepAlive(@NonNull final Integer cId, final boolean enable, final
try {
client.setKeepAlive(enable, initialDelay);
} catch (IOException e) {
tcpEvtListener.onError(cId, e.getMessage());
tcpEvtListener.onError(cId, e);
}
}

Expand All @@ -182,7 +201,7 @@ public void resume(final int cId) {
TcpSocketClient client = getTcpClient(cId);
client.resume();
}

@SuppressWarnings("unused")
@ReactMethod
public void addListener(String eventName) {
Expand Down Expand Up @@ -260,21 +279,21 @@ private void selectNetwork(@Nullable final String iface, @Nullable final String
private TcpSocketClient getTcpClient(final int id) {
TcpSocket socket = socketMap.get(id);
if (socket == null) {
throw new IllegalArgumentException(TAG + "No socket with id " + id);
throw new IllegalArgumentException("No socket with id " + id);
}
if (!(socket instanceof TcpSocketClient)) {
throw new IllegalArgumentException(TAG + "Socket with id " + id + " is not a client");
throw new IllegalArgumentException("Socket with id " + id + " is not a client");
}
return (TcpSocketClient) socket;
}

private TcpSocketServer getTcpServer(final int id) {
TcpSocket socket = socketMap.get(id);
if (socket == null) {
throw new IllegalArgumentException(TAG + "No socket with id " + id);
throw new IllegalArgumentException("No server socket with id " + id);
}
if (!(socket instanceof TcpSocketServer)) {
throw new IllegalArgumentException(TAG + "Socket with id " + id + " is not a server");
throw new IllegalArgumentException("Server socket with id " + id + " is not a server");
}
return (TcpSocketServer) socket;
}
Expand Down
Loading