diff --git a/Net/src/PollSet.cpp b/Net/src/PollSet.cpp index 487f067e4c..f0319120d1 100644 --- a/Net/src/PollSet.cpp +++ b/Net/src/PollSet.cpp @@ -263,10 +263,8 @@ class PollSetImpl if (it->fd == fd) { it->events = 0; - if (mode & PollSet::POLL_READ) - it->events |= POLLIN; - if (mode & PollSet::POLL_WRITE) - it->events |= POLLOUT; + it->revents = 0; + setMode(it->fd, it->events, mode); } } } @@ -307,11 +305,7 @@ class PollSetImpl pfd.fd = it->first; pfd.events = 0; pfd.revents = 0; - if (it->second & PollSet::POLL_READ) - pfd.events |= POLLIN; - if (it->second & PollSet::POLL_WRITE) - pfd.events |= POLLOUT; - + setMode(pfd.fd, pfd.events, it->second); _pollfds.push_back(pfd); } _addMap.clear(); @@ -325,9 +319,15 @@ class PollSetImpl { Poco::Timestamp start; #ifdef _WIN32 - rc = WSAPoll(&_pollfds[0], static_cast(_pollfds.size()), static_cast(timeout.totalMilliseconds())); + rc = WSAPoll(&_pollfds[0], static_cast(_pollfds.size()), static_cast(remainingTime.totalMilliseconds())); + // see https://github.com/pocoproject/poco/issues/3248 + if ((remainingTime > 0) && (rc > 0) && !hasSignaledFDs()) + { + rc = -1; + WSASetLastError(WSAEINTR); + } #else - rc = ::poll(&_pollfds[0], _pollfds.size(), timeout.totalMilliseconds()); + rc = ::poll(&_pollfds[0], _pollfds.size(), remainingTime.totalMilliseconds()); #endif if (rc < 0 && SocketImpl::lastError() == POCO_EINTR) { @@ -352,16 +352,20 @@ class PollSetImpl std::map::const_iterator its = _socketMap.find(it->fd); if (its != _socketMap.end()) { - if (it->revents & POLLIN) + if ((it->revents & POLLIN) +#ifdef _WIN32 + || (it->revents & POLLHUP) +#endif + ) result[its->second] |= PollSet::POLL_READ; - if (it->revents & POLLOUT) + if ((it->revents & POLLOUT) +#ifdef _WIN32 + && (_wantPOLLOUT.find(it->fd) != _wantPOLLOUT.end()) +#endif + ) result[its->second] |= PollSet::POLL_WRITE; if (it->revents & POLLERR) result[its->second] |= PollSet::POLL_ERROR; -#ifdef _WIN32 - if (it->revents & POLLHUP) - result[its->second] |= PollSet::POLL_READ; -#endif } it->revents = 0; } @@ -372,8 +376,52 @@ class PollSetImpl } private: + +#ifdef _WIN32 + + void setMode(poco_socket_t fd, short& target, int mode) + { + if (mode & PollSet::POLL_READ) + target |= POLLIN; + + if (mode & PollSet::POLL_WRITE) + _wantPOLLOUT.insert(fd); + else + _wantPOLLOUT.erase(fd); + target |= POLLOUT; + } + + bool hasSignaledFDs() + { + for (const auto& pollfd : _pollfds) + { + if ((pollfd.revents | POLLOUT) && + (_wantPOLLOUT.find(pollfd.fd) != _wantPOLLOUT.end())) + { + return true; + } + } + return false; + } + +#else + + void setMode(poco_socket_t fd, short& target, int mode) + { + if (mode & PollSet::POLL_READ) + target |= POLLIN; + + if (mode & PollSet::POLL_WRITE) + target |= POLLOUT; + } + +#endif + mutable Poco::FastMutex _mutex; std::map _socketMap; +#ifdef _WIN32 + std::set _wantPOLLOUT; +#endif std::map _addMap; std::set _removeSet; std::vector _pollfds; diff --git a/Net/testsuite/src/EchoServer.cpp b/Net/testsuite/src/EchoServer.cpp index 74a5370fd0..7db735ef2a 100644 --- a/Net/testsuite/src/EchoServer.cpp +++ b/Net/testsuite/src/EchoServer.cpp @@ -23,7 +23,8 @@ using Poco::Net::SocketAddress; EchoServer::EchoServer(): _socket(SocketAddress()), _thread("EchoServer"), - _stop(false) + _stop(false), + _done(false) { _thread.start(*this); _ready.wait(); @@ -33,7 +34,8 @@ EchoServer::EchoServer(): EchoServer::EchoServer(const Poco::Net::SocketAddress& address): _socket(address), _thread("EchoServer"), - _stop(false) + _stop(false), + _done(false) { _thread.start(*this); _ready.wait(); @@ -78,5 +80,18 @@ void EchoServer::run() } } } + _done = true; +} + + +void EchoServer::stop() +{ + _stop = true; +} + + +bool EchoServer::done() +{ + return _done; } diff --git a/Net/testsuite/src/EchoServer.h b/Net/testsuite/src/EchoServer.h index d724d495ac..4ff35baee6 100644 --- a/Net/testsuite/src/EchoServer.h +++ b/Net/testsuite/src/EchoServer.h @@ -36,15 +36,22 @@ class EchoServer: public Poco::Runnable Poco::UInt16 port() const; /// Returns the port the echo server is /// listening on. - + void run(); /// Does the work. - + + void stop(); + /// Sets the stop flag. + + bool done(); + /// Retruns true if if server is done. + private: Poco::Net::ServerSocket _socket; Poco::Thread _thread; Poco::Event _ready; bool _stop; + bool _done; }; diff --git a/Net/testsuite/src/PollSetTest.cpp b/Net/testsuite/src/PollSetTest.cpp index 9cb37d43c2..774b44dced 100644 --- a/Net/testsuite/src/PollSetTest.cpp +++ b/Net/testsuite/src/PollSetTest.cpp @@ -28,6 +28,7 @@ using Poco::Net::ConnectionRefusedException; using Poco::Net::PollSet; using Poco::Timespan; using Poco::Stopwatch; +using Poco::Thread; PollSetTest::PollSetTest(const std::string& name): CppUnit::TestCase(name) @@ -76,7 +77,7 @@ void PollSetTest::testPoll() assertTrue (sm.find(ss1) != sm.end()); assertTrue (sm.find(ss2) == sm.end()); assertTrue (sm.find(ss1)->second == PollSet::POLL_WRITE); - assertTrue (sw.elapsed() < 100000); + assertTrue (sw.elapsed() < 1100000); ps.update(ss1, PollSet::POLL_READ); @@ -87,7 +88,7 @@ void PollSetTest::testPoll() assertTrue (sm.find(ss1) != sm.end()); assertTrue (sm.find(ss2) == sm.end()); assertTrue (sm.find(ss1)->second == PollSet::POLL_READ); - assertTrue (sw.elapsed() < 100000); + assertTrue (sw.elapsed() < 1100000); int n = ss1.receiveBytes(buffer, sizeof(buffer)); assertTrue (n == 5); @@ -100,7 +101,7 @@ void PollSetTest::testPoll() assertTrue (sm.find(ss1) == sm.end()); assertTrue (sm.find(ss2) != sm.end()); assertTrue (sm.find(ss2)->second == PollSet::POLL_READ); - assertTrue (sw.elapsed() < 100000); + assertTrue (sw.elapsed() < 1100000); n = ss2.receiveBytes(buffer, sizeof(buffer)); assertTrue (n == 5); @@ -125,6 +126,69 @@ void PollSetTest::testPoll() } +void PollSetTest::testPollNoServer() +{ + StreamSocket ss1; + StreamSocket ss2; + + ss1.connectNB(SocketAddress("127.0.0.1", 0xFEFE)); + ss2.connectNB(SocketAddress("127.0.0.1", 0xFEFF)); + PollSet ps; + assertTrue(ps.empty()); + ps.add(ss1, PollSet::POLL_READ); + ps.add(ss2, PollSet::POLL_READ); + assertTrue(!ps.empty()); + assertTrue(ps.has(ss1)); + assertTrue(ps.has(ss2)); + PollSet::SocketModeMap sm; + Stopwatch sw; sw.start(); + do + { + sm = ps.poll(Timespan(1000000)); + if (sw.elapsedSeconds() > 10) fail(); + } while (sm.size() < 2); + assertTrue(sm.size() == 2); + for (auto s : sm) + assertTrue(0 != (s.second | PollSet::POLL_ERROR)); +} + + +void PollSetTest::testPollClosedServer() +{ + EchoServer echoServer1; + EchoServer echoServer2; + StreamSocket ss1; + StreamSocket ss2; + + ss1.connectNB(SocketAddress("127.0.0.1", echoServer1.port())); + ss2.connectNB(SocketAddress("127.0.0.1", echoServer2.port())); + PollSet ps; + assertTrue(ps.empty()); + ps.add(ss1, PollSet::POLL_READ); + ps.add(ss2, PollSet::POLL_READ); + assertTrue(!ps.empty()); + assertTrue(ps.has(ss1)); + assertTrue(ps.has(ss2)); + + echoServer1.stop(); + ss1.sendBytes("HELLO", 5); + while (!echoServer1.done()) Thread::sleep(10); + echoServer2.stop(); + ss2.sendBytes("HELLO", 5); + while (!echoServer2.done()) Thread::sleep(10); + PollSet::SocketModeMap sm; + Stopwatch sw; sw.start(); + do + { + sm = ps.poll(Timespan(1000000)); + if (sw.elapsedSeconds() > 10) fail(); + } while (sm.size() < 2); + assertTrue(sm.size() == 2); + assertTrue(0 == ss1.receiveBytes(0, 0)); + assertTrue(0 == ss2.receiveBytes(0, 0)); +} + + void PollSetTest::setUp() { } @@ -140,6 +204,8 @@ CppUnit::Test* PollSetTest::suite() CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("PollSetTest"); CppUnit_addTest(pSuite, PollSetTest, testPoll); + CppUnit_addTest(pSuite, PollSetTest, testPollNoServer); + CppUnit_addTest(pSuite, PollSetTest, testPollClosedServer); return pSuite; } diff --git a/Net/testsuite/src/PollSetTest.h b/Net/testsuite/src/PollSetTest.h index 33a76a289a..9a8e70e64d 100644 --- a/Net/testsuite/src/PollSetTest.h +++ b/Net/testsuite/src/PollSetTest.h @@ -25,6 +25,8 @@ class PollSetTest: public CppUnit::TestCase ~PollSetTest(); void testPoll(); + void testPollNoServer(); + void testPollClosedServer(); void setUp(); void tearDown();