Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SocketWrapper - copyable networking clients #768

Merged
merged 1 commit into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions libraries/Ethernet/src/EthernetClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
#define ethernetclient_h

#include "Ethernet.h"
#include "MbedClient.h"
#include "AClient.h"

namespace arduino {

class EthernetClient : public MbedClient {
class EthernetClient : public AClient {
NetworkInterface *getNetwork() {
return Ethernet.getNetwork();
}
Expand Down
4 changes: 2 additions & 2 deletions libraries/Ethernet/src/EthernetSSLClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
#define ETHERNETSSLCLIENT_H

#include "EthernetClient.h"
#include "MbedSSLClient.h"
#include "AClient.h"

extern const char CA_CERTIFICATES[];

namespace arduino {

class EthernetSSLClient : public arduino::MbedSSLClient {
class EthernetSSLClient : public arduino::ASslClient {
NetworkInterface *getNetwork() {
return Ethernet.getNetwork();
}
Expand Down
2 changes: 1 addition & 1 deletion libraries/GSM/src/GSMClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@

#include "GSMClient.h"

arduino::GSMClient::GSMClient(): MbedClient(100) {
arduino::GSMClient::GSMClient(): AClient(100) {

}
4 changes: 2 additions & 2 deletions libraries/GSM/src/GSMClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
#define gsmclient_h

#include "GSM.h"
#include "MbedClient.h"
#include "AClient.h"

namespace arduino {

class GSMClient : public MbedClient {
class GSMClient : public AClient {
public:
GSMClient();

Expand Down
2 changes: 1 addition & 1 deletion libraries/GSM/src/GSMSSLClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@

#include "GSMSSLClient.h"

arduino::GSMSSLClient::GSMSSLClient(): MbedSSLClient(100) {
arduino::GSMSSLClient::GSMSSLClient(): ASslClient(100) {

}
4 changes: 2 additions & 2 deletions libraries/GSM/src/GSMSSLClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
#define GSMSSLCLIENT_H

#include "GSM.h"
#include "MbedSSLClient.h"
#include "AClient.h"

extern const char CA_CERTIFICATES[];

namespace arduino {

class GSMSSLClient : public arduino::MbedSSLClient {
class GSMSSLClient : public arduino::ASslClient {
public:
GSMSSLClient();

Expand Down
18 changes: 15 additions & 3 deletions libraries/SE05X/src/WiFiSSLSE050Client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,25 @@

#include "WiFiSSLSE050Client.h"

arduino::WiFiSSLSE050Client::WiFiSSLSE050Client() {
onBeforeConnect(mbed::callback(this, &WiFiSSLSE050Client::setRootCAClientCertKey));
arduino::MbedSSLSE050Client::MbedSSLSE050Client() {
onBeforeConnect(mbed::callback(this, &MbedSSLSE050Client::setRootCAClientCertKey));
};

void arduino::WiFiSSLSE050Client::setEccSlot(int KeySlot, const byte cert[], int certLen) {
void arduino::MbedSSLSE050Client::setEccSlot(int KeySlot, const byte cert[], int certLen) {

_keySlot = KeySlot;
_client_cert_len = certLen;
_client_cert = cert;
}

void WiFiSSLSE050Client::setEccSlot(int KeySlot, const byte cert[], int certLen) {
if (!client) {
newMbedClient();
}
static_cast<MbedSSLSE050Client*>(client.get())->setEccSlot(KeySlot, cert, certLen);
}

void WiFiSSLSE050Client::newMbedClient() {
client.reset(new MbedSSLSE050Client());
client->setNetwork(getNetwork());
}
24 changes: 18 additions & 6 deletions libraries/SE05X/src/WiFiSSLSE050Client.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,17 @@

#include "SE05X.h"
#include "WiFiSSLClient.h"
#include "MbedSSLClient.h"

extern const char CA_CERTIFICATES[];

namespace arduino {

class WiFiSSLSE050Client : public arduino::WiFiSSLClient {
class MbedSSLSE050Client : public arduino::MbedSSLClient {

public:
WiFiSSLSE050Client();
virtual ~WiFiSSLSE050Client() {
stop();
}
MbedSSLSE050Client();

void setEccSlot(int KeySlot, const byte cert[], int certLen);

private:
Expand All @@ -57,14 +56,27 @@ class WiFiSSLSE050Client : public arduino::WiFiSSLClient {
return 0;
}

if( NSAPI_ERROR_OK != ((TLSSocket*)sock)->set_client_cert_key((void*)_client_cert, (size_t)_client_cert_len, &_keyObject, SE05X.getDeviceCtx())) {
if( NSAPI_ERROR_OK != ((TLSSocket*)sock)->set_client_cert_key((void*)_client_cert,
(size_t)_client_cert_len,
&_keyObject,
SE05X.getDeviceCtx())) {
return 0;
}

return 1;
}
};

class WiFiSSLSE050Client : public arduino::WiFiSSLClient {

public:

void setEccSlot(int KeySlot, const byte cert[], int certLen);
JAndrassy marked this conversation as resolved.
Show resolved Hide resolved

protected:
virtual void newMbedClient();
};

}

#endif /* WIFISSLSE050CLIENT_H */
149 changes: 149 additions & 0 deletions libraries/SocketWrapper/src/AClient.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@

#include "AClient.h"
#include "MbedSSLClient.h"

AClient::AClient(unsigned long timeout) {
setSocketTimeout(timeout);
}

void arduino::AClient::newMbedClient() {
client.reset(new MbedClient());
client->setNetwork(getNetwork());
}

arduino::AClient::operator bool() {
return client && *client;
}

void arduino::AClient::setSocket(Socket *sock) {
if (!client) {
newMbedClient();
}
client->setSocket(sock);
}

void arduino::AClient::setSocketTimeout(unsigned long timeout) {
if (!client) {
newMbedClient();
}
client->setSocketTimeout(timeout);
}

int arduino::AClient::connect(IPAddress ip, uint16_t port) {
if (!client) {
newMbedClient();
}
return client->connect(ip, port);
}

int arduino::AClient::connect(const char *host, uint16_t port) {
if (!client) {
newMbedClient();
}
return client->connect(host, port);
}

int arduino::AClient::connectSSL(IPAddress ip, uint16_t port) {
if (!client) {
newMbedClient();
}
return client->connectSSL(ip, port);
}

int arduino::AClient::connectSSL(const char *host, uint16_t port, bool disableSNI) {
if (!client) {
newMbedClient();
}
return client->connectSSL(host, port, disableSNI);
}

void arduino::AClient::stop() {
if (!client)
return;
client->stop();
}

uint8_t arduino::AClient::connected() {
if (!client)
return false;
return client->connected();
}

uint8_t arduino::AClient::status() {
if (!client)
return false;
return client->status();
}

IPAddress arduino::AClient::remoteIP() {
if (!client)
return INADDR_NONE;
return client->remoteIP();
}

uint16_t arduino::AClient::remotePort() {
if (!client)
return 0;
return client->remotePort();
}

size_t arduino::AClient::write(uint8_t b) {
if (!client)
return 0;
return client->write(b);
}

size_t arduino::AClient::write(const uint8_t *buf, size_t size) {
if (!client)
return 0;
return client->write(buf, size);
}

void arduino::AClient::flush() {
if (!client)
return;
client->flush();
}

int arduino::AClient::available() {
if (!client)
return 0;
return client->available();
}

int arduino::AClient::read() {
if (!client)
return -1;
return client->read();
}

int arduino::AClient::read(uint8_t *buf, size_t size) {
if (!client)
return 0;
return client->read(buf, size);
}

int arduino::AClient::peek() {
if (!client)
return -1;
return client->peek();
}

void arduino::ASslClient::newMbedClient() {
client.reset(new MbedSSLClient());
client->setNetwork(getNetwork());
}

void arduino::ASslClient::disableSNI(bool statusSNI) {
if (!client) {
newMbedClient();
}
static_cast<MbedSSLClient*>(client.get())->disableSNI(statusSNI);
}

void arduino::ASslClient::appendCustomCACert(const char* ca_cert) {
if (!client) {
newMbedClient();
}
static_cast<MbedSSLClient*>(client.get())->appendCustomCACert(ca_cert);
}
85 changes: 85 additions & 0 deletions libraries/SocketWrapper/src/AClient.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
AClient.h - Copyable Client implementation for Mbed Core

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*/

#ifndef MBEDACLIENT_H
#define MBEDACLIENT_H

#include <Arduino.h>
#include "MbedClient.h"

namespace arduino {

class AClient : public Client {
public:

AClient() {}
AClient(unsigned long timeout);

virtual int connect(IPAddress ip, uint16_t port);
virtual int connect(const char *host, uint16_t port);
int connectSSL(IPAddress ip, uint16_t port);
int connectSSL(const char* host, uint16_t port, bool disableSNI = false);
virtual void stop();

virtual explicit operator bool();
virtual uint8_t connected();
uint8_t status();

IPAddress remoteIP();
uint16_t remotePort();

virtual size_t write(uint8_t);
virtual size_t write(const uint8_t *buf, size_t size);
virtual void flush();

virtual int available();
virtual int read();
virtual int read(uint8_t *buf, size_t size);
virtual int peek();

using Print::write;

void setSocketTimeout(unsigned long timeout);

protected:
friend class EthernetServer;
friend class WiFiServer;

std::shared_ptr<MbedClient> client;
virtual NetworkInterface* getNetwork() = 0;
virtual void newMbedClient();
void setSocket(Socket* sock);

};

class ASslClient : public AClient {
public:

ASslClient() {}
ASslClient(unsigned long timeout) : AClient(timeout) {}

void disableSNI(bool statusSNI);

void appendCustomCACert(const char* ca_cert);

protected:
virtual void newMbedClient();
};

}
#endif
Loading
Loading