diff --git a/libraries/Ethernet/src/EthernetClient.h b/libraries/Ethernet/src/EthernetClient.h index 9d9f9b77a..4eae838a8 100644 --- a/libraries/Ethernet/src/EthernetClient.h +++ b/libraries/Ethernet/src/EthernetClient.h @@ -21,11 +21,11 @@ #define ethernetclient_h #include "Ethernet.h" -#include "MbedClient.h" +#include "AClient.h" namespace arduino { -class EthernetClient : public MbedClient { +class EthernetClient : public AClient { NetworkInterface *getNetwork() { return Ethernet.getNetwork(); } diff --git a/libraries/Ethernet/src/EthernetSSLClient.h b/libraries/Ethernet/src/EthernetSSLClient.h index aa6fbfebc..79f26d3ed 100644 --- a/libraries/Ethernet/src/EthernetSSLClient.h +++ b/libraries/Ethernet/src/EthernetSSLClient.h @@ -21,13 +21,13 @@ #define ETHERNETSSLCLIENT_H #include "EthernetClient.h" -#include "MbedSSLClient.h" +#include "AClient.h" extern const char CA_CERTIFICATES[]; namespace arduino { -class EthernetSSLClient : public arduino::MbedSSLClient { +class EthernetSSLClient : public arduino::ASslClient { NetworkInterface *getNetwork() { return Ethernet.getNetwork(); } diff --git a/libraries/GSM/src/GSMClient.cpp b/libraries/GSM/src/GSMClient.cpp index 71043da70..2db67df7d 100644 --- a/libraries/GSM/src/GSMClient.cpp +++ b/libraries/GSM/src/GSMClient.cpp @@ -19,6 +19,6 @@ #include "GSMClient.h" -arduino::GSMClient::GSMClient(): MbedClient(100) { +arduino::GSMClient::GSMClient(): AClient(100) { } diff --git a/libraries/GSM/src/GSMClient.h b/libraries/GSM/src/GSMClient.h index 8ac465975..52c794a75 100644 --- a/libraries/GSM/src/GSMClient.h +++ b/libraries/GSM/src/GSMClient.h @@ -21,11 +21,11 @@ #define gsmclient_h #include "GSM.h" -#include "MbedClient.h" +#include "AClient.h" namespace arduino { -class GSMClient : public MbedClient { +class GSMClient : public AClient { public: GSMClient(); diff --git a/libraries/GSM/src/GSMSSLClient.cpp b/libraries/GSM/src/GSMSSLClient.cpp index c953adb4f..0070f210e 100644 --- a/libraries/GSM/src/GSMSSLClient.cpp +++ b/libraries/GSM/src/GSMSSLClient.cpp @@ -19,6 +19,6 @@ #include "GSMSSLClient.h" -arduino::GSMSSLClient::GSMSSLClient(): MbedSSLClient(100) { +arduino::GSMSSLClient::GSMSSLClient(): ASslClient(100) { } diff --git a/libraries/GSM/src/GSMSSLClient.h b/libraries/GSM/src/GSMSSLClient.h index 2ea0ae713..ab07d1f79 100644 --- a/libraries/GSM/src/GSMSSLClient.h +++ b/libraries/GSM/src/GSMSSLClient.h @@ -21,13 +21,13 @@ #define GSMSSLCLIENT_H #include "GSM.h" -#include "MbedSSLClient.h" +#include "AClient.h" extern const char CA_CERTIFICATES[]; namespace arduino { -class GSMSSLClient : public arduino::MbedSSLClient { +class GSMSSLClient : public arduino::ASslClient { public: GSMSSLClient(); diff --git a/libraries/SE05X/src/WiFiSSLSE050Client.cpp b/libraries/SE05X/src/WiFiSSLSE050Client.cpp index e0d78db5d..7a3b88555 100644 --- a/libraries/SE05X/src/WiFiSSLSE050Client.cpp +++ b/libraries/SE05X/src/WiFiSSLSE050Client.cpp @@ -19,13 +19,25 @@ #include "WiFiSSLSE050Client.h" -arduino::WiFiSSLSE050Client::WiFiSSLSE050Client() { - onBeforeConnect(mbed::callback(this, &WiFiSSLSE050Client::setRootCAClientCertKey)); +arduino::MbedSSLSE050Client::MbedSSLSE050Client() { + onBeforeConnect(mbed::callback(this, &MbedSSLSE050Client::setRootCAClientCertKey)); }; -void arduino::WiFiSSLSE050Client::setEccSlot(int KeySlot, const byte cert[], int certLen) { +void arduino::MbedSSLSE050Client::setEccSlot(int KeySlot, const byte cert[], int certLen) { _keySlot = KeySlot; _client_cert_len = certLen; _client_cert = cert; } + +void WiFiSSLSE050Client::setEccSlot(int KeySlot, const byte cert[], int certLen) { + if (!client) { + newMbedClient(); + } + static_cast(client.get())->setEccSlot(KeySlot, cert, certLen); +} + +void WiFiSSLSE050Client::newMbedClient() { + client.reset(new MbedSSLSE050Client()); + client->setNetwork(getNetwork()); +} diff --git a/libraries/SE05X/src/WiFiSSLSE050Client.h b/libraries/SE05X/src/WiFiSSLSE050Client.h index cb223255f..c89e4b96e 100644 --- a/libraries/SE05X/src/WiFiSSLSE050Client.h +++ b/libraries/SE05X/src/WiFiSSLSE050Client.h @@ -23,18 +23,17 @@ #include "SE05X.h" #include "WiFiSSLClient.h" +#include "MbedSSLClient.h" extern const char CA_CERTIFICATES[]; namespace arduino { -class WiFiSSLSE050Client : public arduino::WiFiSSLClient { +class MbedSSLSE050Client : public arduino::MbedSSLClient { public: - WiFiSSLSE050Client(); - virtual ~WiFiSSLSE050Client() { - stop(); - } + MbedSSLSE050Client(); + void setEccSlot(int KeySlot, const byte cert[], int certLen); private: @@ -65,6 +64,14 @@ class WiFiSSLSE050Client : public arduino::WiFiSSLClient { } }; +class WiFiSSLSE050Client : public arduino::WiFiSSLClient { + + void setEccSlot(int KeySlot, const byte cert[], int certLen); + +protected: + virtual void newMbedClient(); +}; + } #endif /* WIFISSLSE050CLIENT_H */ diff --git a/libraries/SocketWrapper/src/AClient.cpp b/libraries/SocketWrapper/src/AClient.cpp new file mode 100644 index 000000000..a1e4fae46 --- /dev/null +++ b/libraries/SocketWrapper/src/AClient.cpp @@ -0,0 +1,143 @@ + +#include "AClient.h" +#include "MbedSSLClient.h" + +AClient::AClient(unsigned long timeout) { + setSocketTimeout(timeout); +} + +void arduino::AClient::newMbedClient() { + client.reset(new MbedClient()); + client->setNetwork(getNetwork()); +} + +arduino::AClient::operator bool() { + return client && *client; +} + +void arduino::AClient::setSocket(Socket *sock) { + if (!client) { + newMbedClient(); + } + client->setSocket(sock); +} + +void arduino::AClient::setSocketTimeout(unsigned long timeout) { + if (!client) { + newMbedClient(); + } + client->setTimeout(timeout); +} + +int arduino::AClient::connect(IPAddress ip, uint16_t port) { + if (!client) { + newMbedClient(); + } + return client->connect(ip, port); +} + +int arduino::AClient::connect(const char *host, uint16_t port) { + if (!client) { + newMbedClient(); + } + return client->connect(host, port); +} + +int arduino::AClient::connectSSL(IPAddress ip, uint16_t port) { + if (!client) { + newMbedClient(); + } + return client->connectSSL(ip, port); +} + +int arduino::AClient::connectSSL(const char *host, uint16_t port, bool disableSNI) { + if (!client) { + newMbedClient(); + } + return client->connectSSL(host, port, disableSNI); +} + +void arduino::AClient::stop() { + if (!client) + return; + client->stop(); +} + +uint8_t arduino::AClient::connected() { + if (!client) + return false; + return client->connected(); +} + +IPAddress arduino::AClient::remoteIP() { + if (!client) + return INADDR_NONE; + return client->remoteIP(); +} + +uint16_t arduino::AClient::remotePort() { + if (!client) + return 0; + return client->remotePort(); +} + +size_t arduino::AClient::write(uint8_t b) { + if (!client) + return 0; + return client->write(b); +} + +size_t arduino::AClient::write(const uint8_t *buf, size_t size) { + if (!client) + return 0; + return client->write(buf, size); +} + +void arduino::AClient::flush() { + if (!client) + return; + client->flush(); +} + +int arduino::AClient::available() { + if (!client) + return 0; + return client->available(); +} + +int arduino::AClient::read() { + if (!client) + return -1; + return client->read(); +} + +int arduino::AClient::read(uint8_t *buf, size_t size) { + if (!client) + return 0; + return client->read(buf, size); +} + +int arduino::AClient::peek() { + if (!client) + return -1; + return client->peek(); +} + +void arduino::ASslClient::newMbedClient() { + client.reset(new MbedSSLClient()); + client->setNetwork(getNetwork()); +} + +void arduino::ASslClient::disableSNI(bool statusSNI) { + if (!client) { + newMbedClient(); + } + static_cast(client.get())->disableSNI(statusSNI); +} + +void arduino::ASslClient::appendCustomCACert(const char* ca_cert) { + if (!client) { + newMbedClient(); + } + static_cast(client.get())->appendCustomCACert(ca_cert); +} diff --git a/libraries/SocketWrapper/src/AClient.h b/libraries/SocketWrapper/src/AClient.h new file mode 100644 index 000000000..c93bea0f3 --- /dev/null +++ b/libraries/SocketWrapper/src/AClient.h @@ -0,0 +1,85 @@ +/* + AClient.h - Copyable Client implementation for Mbed Core + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA +*/ + +#ifndef MBEDACLIENT_H +#define MBEDACLIENT_H + +#include +#include "MbedClient.h" + +namespace arduino { + +class AClient : public Client { +public: + + AClient() {} + AClient(unsigned long timeout); + + virtual int connect(IPAddress ip, uint16_t port); + virtual int connect(const char *host, uint16_t port); + int connectSSL(IPAddress ip, uint16_t port); + int connectSSL(const char* host, uint16_t port, bool disableSNI = false); + virtual void stop(); + + virtual explicit operator bool(); + virtual uint8_t connected(); + uint8_t status(); + + IPAddress remoteIP(); + uint16_t remotePort(); + + virtual size_t write(uint8_t); + virtual size_t write(const uint8_t *buf, size_t size); + virtual void flush(); + + virtual int available(); + virtual int read(); + virtual int read(uint8_t *buf, size_t size); + virtual int peek(); + + using Print::write; + + void setSocketTimeout(unsigned long timeout); + +protected: + friend class EthernetServer; + friend class WiFiServer; + + std::shared_ptr client; + virtual NetworkInterface* getNetwork() = 0; + virtual void newMbedClient(); + void setSocket(Socket* sock); + +}; + +class ASslClient : public AClient { +public: + + ASslClient() {} + ASslClient(unsigned long timeout) : AClient(timeout) {} + + void disableSNI(bool statusSNI); + + void appendCustomCACert(const char* ca_cert); + +protected: + virtual void newMbedClient(); +}; + +} +#endif diff --git a/libraries/SocketWrapper/src/MbedClient.cpp b/libraries/SocketWrapper/src/MbedClient.cpp index 49265c002..b25c851cb 100644 --- a/libraries/SocketWrapper/src/MbedClient.cpp +++ b/libraries/SocketWrapper/src/MbedClient.cpp @@ -24,7 +24,7 @@ void arduino::MbedClient::readSocket() { continue; } mutex->lock(); - if (sock == nullptr || (closing && borrowed_socket)) { + if (sock == nullptr) { goto cleanup; } ret = sock->recv(data, rxBuffer.availableForStore()); @@ -270,7 +270,7 @@ void arduino::MbedClient::stop() { if (mutex != nullptr) { mutex->lock(); } - if (sock != nullptr && borrowed_socket == false) { + if (sock != nullptr) { if (_own_socket) { delete sock; } else { @@ -278,7 +278,6 @@ void arduino::MbedClient::stop() { } sock = nullptr; } - closing = true; if (mutex != nullptr) { mutex->unlock(); } diff --git a/libraries/SocketWrapper/src/MbedClient.h b/libraries/SocketWrapper/src/MbedClient.h index eca0e5a34..9db17fa78 100644 --- a/libraries/SocketWrapper/src/MbedClient.h +++ b/libraries/SocketWrapper/src/MbedClient.h @@ -35,16 +35,7 @@ namespace arduino { -class MbedClient : public arduino::Client { -private: - // Helper for copy constructor and assignment operator - void copyClient(const MbedClient& orig) { - auto _sock = orig.sock; - auto _m = (MbedClient*)&orig; - _m->borrowed_socket = true; - _m->stop(); - this->setSocket(_sock); - } +class MbedClient { public: MbedClient(); @@ -53,31 +44,21 @@ class MbedClient : public arduino::Client { _timeout = timeout; } - // Copy constructor, to be used when a Client returned by server.available() - // needs to "survive" event if it goes out of scope - // Sample usage: Client* new_client = new Client(existing_client) - MbedClient(const MbedClient& orig) { - copyClient(orig); - } - - MbedClient& operator=(const MbedClient& orig) { - copyClient(orig); - return *this; - } - virtual ~MbedClient() { stop(); } + void setNetwork(NetworkInterface* network) {_network = network;} + uint8_t status(); int connect(SocketAddress socketAddress); - int connect(IPAddress ip, uint16_t port); - int connect(const char* host, uint16_t port); + virtual int connect(IPAddress ip, uint16_t port); + virtual int connect(const char* host, uint16_t port); int connectSSL(SocketAddress socketAddress); int connectSSL(IPAddress ip, uint16_t port); int connectSSL(const char* host, uint16_t port, bool disableSNI = false); size_t write(uint8_t); - size_t write(const uint8_t* buf, size_t size) override; + size_t write(const uint8_t* buf, size_t size); int available(); int read(); int read(uint8_t* buf, size_t size); @@ -103,10 +84,10 @@ class MbedClient : public arduino::Client { friend class MbedSSLClient; friend class MbedSocketClass; - using Print::write; - protected: - virtual NetworkInterface* getNetwork() = 0; + NetworkInterface* getNetwork() {return _network;} + + NetworkInterface* _network = nullptr; Socket* sock = nullptr; void onBeforeConnect(mbed::Callback cb) { @@ -114,11 +95,12 @@ class MbedClient : public arduino::Client { } private: + + MbedClient(const MbedClient&) : _timeout(0) {} + RingBufferN rxBuffer; bool _status = false; - bool borrowed_socket = false; bool _own_socket = false; - bool closing = false; mbed::Callback beforeConnect; SocketAddress address; rtos::Thread* reader_th = nullptr; diff --git a/libraries/WiFi/src/WiFiClient.h b/libraries/WiFi/src/WiFiClient.h index f60a978d3..0cb6781ab 100644 --- a/libraries/WiFi/src/WiFiClient.h +++ b/libraries/WiFi/src/WiFiClient.h @@ -21,11 +21,11 @@ #define wificlient_h #include "WiFi.h" -#include "MbedClient.h" +#include "AClient.h" namespace arduino { -class WiFiClient : public MbedClient { +class WiFiClient : public AClient { NetworkInterface *getNetwork() { return WiFi.getNetwork(); } diff --git a/libraries/WiFi/src/WiFiSSLClient.h b/libraries/WiFi/src/WiFiSSLClient.h index c4751e10a..366eda4b5 100644 --- a/libraries/WiFi/src/WiFiSSLClient.h +++ b/libraries/WiFi/src/WiFiSSLClient.h @@ -21,13 +21,14 @@ #define WIFISSLCLIENT_H #include "WiFi.h" -#include "MbedSSLClient.h" +#include "AClient.h" extern const char CA_CERTIFICATES[]; namespace arduino { -class WiFiSSLClient : public arduino::MbedSSLClient { +class WiFiSSLClient : public arduino::ASslClient { +protected: NetworkInterface *getNetwork() { return WiFi.getNetwork(); }