From 30ec64566b87cafc2fb192bf0e550a23e5a31f52 Mon Sep 17 00:00:00 2001 From: Tor Lillqvist Date: Sat, 15 Sep 2018 23:29:39 +0300 Subject: [PATCH] Attempt to fix various FakeSocket problems Also add a bit of internals documentation. --- Mobile/TestFakeSocket/TestFakeSocket/main.cpp | 111 ++++++++- net/FakeSocket.cpp | 228 ++++++++++++++---- 2 files changed, 274 insertions(+), 65 deletions(-) diff --git a/Mobile/TestFakeSocket/TestFakeSocket/main.cpp b/Mobile/TestFakeSocket/TestFakeSocket/main.cpp index 22e46d85c..2e18cd8f1 100644 --- a/Mobile/TestFakeSocket/TestFakeSocket/main.cpp +++ b/Mobile/TestFakeSocket/TestFakeSocket/main.cpp @@ -20,7 +20,7 @@ int main(int argc, char **argv) int s1 = fakeSocketSocket(); int s2 = fakeSocketSocket(); - std::cout << "sockets: " << s0 << ", " << s1 << ", " << s2 << "\n"; + std::cout << "sockets: s0=" << s0 << ", s1=" << s1 << ", s2=" << s2 << "\n"; fakeSocketClose(s1); @@ -30,15 +30,26 @@ int main(int argc, char **argv) int rc = fakeSocketListen(s0); if (rc == -1) { - perror("listen"); + perror("listening on s0"); return 1; } - int s3; + int s3, s4; std::thread t0([&] { s3 = fakeSocketAccept4(s0, 0); if (s3 == -1) + { perror("accept"); + return; + } + std::cout << "accepted s3=" << s3 << " from s0\n"; + s4 = fakeSocketAccept4(s0, 0); + if (s4 == -1) + { + perror("accept"); + return; + } + std::cout << "accepted s4=" << s4 << " from s0\n"; }); rc = fakeSocketConnect(s1, s0); @@ -47,20 +58,38 @@ int main(int argc, char **argv) perror("connect"); return 1; } + std::cout << "connected s1\n"; + + rc = fakeSocketConnect(s2, s0); + if (rc == -1) + { + perror("connect"); + return 1; + } + std::cout << "connected s2\n"; t0.join(); - if (s3 == -1) + if (s3 == -1 || s4 == -1) return 1; - rc = fakeSocketWrite(s1, "hello", 6); + rc = fakeSocketWrite(s1, "hello", 5); if (rc == -1) { perror("write"); return 1; } - std::cout << "wrote 'hello'\n"; + std::cout << "wrote 'hello' to s1\n"; + + rc = fakeSocketWrite(s2, "moin", 4); + if (rc == -1) + { + perror("write"); + return 1; + } + std::cout << "wrote 'moin' to s2\n"; char buf[100]; + rc = fakeSocketRead(s3, buf, 100); if (rc == -1) { @@ -68,31 +97,40 @@ int main(int argc, char **argv) return 1; } buf[rc] = 0; - std::cout << "read " << buf << "\n"; + std::cout << "read " << buf << " from s3\n"; - rc = fakeSocketWrite(s1, "goodbye", 7); + rc = fakeSocketRead(s4, buf, 100); + if (rc == -1) + { + perror("read"); + return 1; + } + buf[rc] = 0; + std::cout << "read '" << buf << "' from s4\n"; + + rc = fakeSocketWrite(s3, "goodbye", 7); if (rc == -1) { perror("write"); return 1; } - std::cout << "wrote 'goodbye'\n"; + std::cout << "wrote 'goodbye' to s3\n"; - rc = fakeSocketRead(s3, buf, 4); + rc = fakeSocketRead(s1, buf, 4); if (rc != -1) { std::cerr << "Tried partial read, and succeeded!?\n"; return 1; } - rc = fakeSocketRead(s3, buf, 100); + rc = fakeSocketRead(s1, buf, 100); if (rc == -1) { perror("read"); return 1; } buf[rc] = 0; - std::cout << "read " << buf << "\n"; + std::cout << "read '" << buf << "' from s1\n"; int pipe[2]; rc = fakeSocketPipe2(pipe); @@ -102,6 +140,35 @@ int main(int argc, char **argv) return 1; } + fakeSocketClose(s3); + std::cout << "closed s3\n"; + + rc = fakeSocketRead(s1, buf, 100); + if (rc == -1) + { + perror("read"); + return 1; + } + if (rc != 0) + { + std::cerr << "read '" << buf << "' from s1 after peer s3 was closed!?\n"; + return 1; + } + std::cout << "correctly got eof from s1\n"; + + rc = fakeSocketRead(s1, buf, 100); + if (rc == -1) + { + perror("read"); + return 1; + } + if (rc != 0) + { + std::cerr << "read '" << buf << "' from s1 after peer s3 was closed!?\n"; + return 1; + } + std::cout << "correctly got eof from s1\n"; + rc = fakeSocketWrite(pipe[0], "x", 1); if (rc == -1) { @@ -116,7 +183,25 @@ int main(int argc, char **argv) } if (buf[0] != 'x') { - std::cerr << "Wrote 'x' but read '" << buf[0] << "'\n"; + std::cerr << "wrote 'x' to pipe but read '" << buf[0] << "'\n"; + return 1; + } + + rc = fakeSocketWrite(pipe[1], "y", 1); + if (rc == -1) + { + perror("write"); + return 1; + } + rc = fakeSocketRead(pipe[0], buf, 1); + if (rc == -1) + { + perror("read"); + return 1; + } + if (buf[0] != 'y') + { + std::cerr << "wrote 'y' to pipe but read '" << buf[0] << "'\n"; return 1; } diff --git a/net/FakeSocket.cpp b/net/FakeSocket.cpp index e9681b392..43107ef3e 100644 --- a/net/FakeSocket.cpp +++ b/net/FakeSocket.cpp @@ -19,14 +19,33 @@ #include "FakeSocket.hpp" +// A "fake socket" is represented by a number, a smallish integer, just like a real socket. +// +// There is one FakeSocketPair for each two sequential fake socket numbers. When you create one, you +// will always get the lower (even) number in a pair. The higher number wil be returned if you +// sucessfully call fakeSocketConnect() from the lower number to some other fake socket. +// +// After you create a fake socket, there is basically just two things you can do with it: +// +// 1) Call fakeSocketConnect on it giving another fake socket number to connect to. Once the +// connection is successful, you can call fakeSocketRead() and fakeSocketWrite() on your original +// socket. +// +// 2) Call fakeSocketListen() on it, indicating it is a "server" socket. After that, keep calling +// fakeSocketAccept() and each time that returns successfully, it will return a new fake socket that +// is connected to another fake socket that called fakeSocketConnect() to the server socket. You can +// then call fakeSocketRead() and fakeSocketWrite() on it. +// +// This all is complicated a bit by the fact that all the API is non-blocking. + struct FakeSocketPair { int fd[2]; bool listening; int connectingFd; + bool readable[2]; std::vector buffer[2]; std::mutex *mutex; - // std::condition_variable *cv; FakeSocketPair() { @@ -34,12 +53,14 @@ struct FakeSocketPair fd[1] = -1; listening = false; connectingFd = -1; + readable[0] = false; + readable[1] = false; mutex = new std::mutex(); - // cv = new std::condition_variable(); } }; static std::mutex fdsMutex; +static std::mutex cvMutex; static std::condition_variable cv; // Avoid problems with order of initialisation of static globals. @@ -53,17 +74,16 @@ static std::vector& getFds() int fakeSocketSocket() { std::vector& fds = getFds(); - std::cerr << "----- &fds=" << &fds << " size=" << fds.size() << std::endl; std::lock_guard fdsLock(fdsMutex); - size_t i; - for (i = 0; i < fds.size(); i++) - { - if (fds[i].fd[0] == -1 && fds[i].fd[1] == -1) - break; - } - if (i == fds.size()) - fds.resize(fds.size() + 1); + + // We always allocate a new FakeSocketPair struct. Let's not bother with potential issues with + // reusing them. It isn't like we would be allocating thousands anyway during the typical + // lifetime of an app. + + const int i = fds.size(); + fds.resize(i + 1); + FakeSocketPair& result = fds[i]; result.fd[0] = i*2; @@ -98,32 +118,97 @@ int fakeSocketPipe2(int pipefd[2]) return 0; } -static bool someFdReadable(struct pollfd *pollfds, int nfds) +static std::string pollBits(int bits) +{ + if (bits == 0) + return "-"; + + std::string result; + + if (bits & POLLERR) + { + if (result != "") + result += "+"; + result += "ERR"; + } + if (bits & POLLHUP) + { + if (result != "") + result += "+"; + result += "HUP"; + } + if (bits & POLLIN) + { + if (result != "") + result += "+"; + result += "IN"; + } + if (bits & POLLNVAL) + { + if (result != "") + result += "+"; + result += "NVAL"; + } + if (bits & POLLOUT) + { + if (result != "") + result += "+"; + result += "OUT"; + } + if (bits & POLLPRI) + { + if (result != "") + result += "+"; + result += "PRI"; + } + + return result; +} + +static bool checkForPoll(std::vector& fds, struct pollfd *pollfds, int nfds) { - std::vector& fds = getFds(); bool retval = false; for (int i = 0; i < nfds; i++) { - pollfds[i].revents = 0; - const int K = ((pollfds[i].fd)&1); - if (pollfds[i].events & POLLIN) - if (fds[pollfds[i].fd/2].fd[K] != -1 && fds[pollfds[i].fd/2].buffer[K].size() > 0) + // Caller sets POLLNVAL for invalid fds. + if (pollfds[i].revents != POLLNVAL) + { + pollfds[i].revents = 0; + const int K = ((pollfds[i].fd)&1); + const int N = 1 - K; + if (pollfds[i].events & POLLIN) { - pollfds[i].revents = POLLIN; - retval = true; + if (fds[pollfds[i].fd/2].fd[K] != -1 && + (fds[pollfds[i].fd/2].readable[K] || + (K == 0 && fds[pollfds[i].fd/2].listening && fds[pollfds[i].fd/2].connectingFd != -1))) + { + pollfds[i].revents |= POLLIN; + retval = true; + } } + // With our trivial single-message buffering, a socket is writable if the peer socket is + // open and not readable. + if (pollfds[i].events & POLLOUT) + { + if (fds[pollfds[i].fd/2].fd[N] != -1 && !fds[pollfds[i].fd/2].readable[N]) + { + pollfds[i].revents |= POLLOUT; + retval = true; + } + } + } } return retval; } int fakeSocketPoll(struct pollfd *pollfds, int nfds, int timeout) { - std::cerr << "+++++ Polling " << nfds << " fds: "; + std::cerr << "+++++ Poll "; for (int i = 0; i < nfds; i++) { if (i > 0) std::cerr << ","; - std::cerr << pollfds[i].fd; + std::cerr << pollfds[i].fd << ":" << pollBits(pollfds[i].events); } std::cerr << "\n"; @@ -131,18 +216,34 @@ int fakeSocketPoll(struct pollfd *pollfds, int nfds, int timeout) std::unique_lock fdsLock(fdsMutex); for (int i = 0; i < nfds; i++) { - if (pollfds[i].fd < 1 || pollfds[i].fd/2 >= fds.size()) + if (pollfds[i].fd < 0 || pollfds[i].fd/2 >= fds.size()) { - errno = EBADF; - return -1; + pollfds[i].revents = POLLNVAL; + } + else + { + const int K = ((pollfds[i].fd)&1); + if (fds[pollfds[i].fd/2].fd[K] == -1) + pollfds[i].revents = POLLNVAL; + else + pollfds[i].revents = 0; } } - // Here we lock just the first FakeSocketPair struct, hmm - std::unique_lock fdLock(fds[pollfds[0].fd/2].mutex[0]); + + std::unique_lock cvLock(cvMutex); fdsLock.unlock(); - while (!someFdReadable(pollfds, nfds)) - cv.wait(fdLock); + while (!checkForPoll(fds, pollfds, nfds)) + cv.wait(cvLock); + + std::cerr << "+++++ Poll result: "; + for (int i = 0; i < nfds; i++) + { + if (i > 0) + std::cerr << ","; + std::cerr << pollfds[i].fd << ":" << pollBits(pollfds[i].revents); + } + std::cerr << "\n"; return 0; } @@ -151,7 +252,7 @@ int fakeSocketListen(int fd) { std::vector& fds = getFds(); std::unique_lock fdsLock(fdsMutex); - if (fd < 0 || fd/2 >= fds.size()) + if (fd < 0 || fd/2 >= fds.size() || fds[fd/2].fd[fd&1] == -1) { std::cerr << "+++++ EBADF: Listening on fd " << fd << "\n"; errno = EBADF; @@ -223,13 +324,14 @@ int fakeSocketConnect(int fd1, int fd2) } pair2.connectingFd = fd1; - // pair2.cv->notify_all(); cv.notify_all(); + std::unique_lock cvLock(cvMutex); fdLock2.unlock(); + fdLock1.unlock(); + while (pair1.fd[1] == -1) - // pair1.cv->wait(fdLock1); - cv.wait(fdLock1); + cv.wait(cvLock); assert(pair1.fd[1] == pair1.fd[0] + 1); @@ -268,30 +370,30 @@ int fakeSocketAccept4(int fd, int flags) std::unique_lock fdLock(pair.mutex[0]); fdsLock.unlock(); + std::unique_lock cvLock(cvMutex); + fdLock.unlock(); + while (pair.connectingFd == -1) - // pair.cv->wait(fdLock); - cv.wait(fdLock); + cv.wait(cvLock); assert(pair.connectingFd >= 0 && pair.connectingFd/2 < fds.size()); - FakeSocketPair& pair1 = fds[pair.connectingFd/2]; + FakeSocketPair& pair2 = fds[pair.connectingFd/2]; - std::unique_lock fdLock1(pair1.mutex[0]); + std::unique_lock fdLock1(pair2.mutex[0]); - assert(pair1.fd[1] == -1); - assert(pair1.fd[0] == pair.connectingFd); + assert(pair2.fd[1] == -1); + assert(pair2.fd[0] == pair.connectingFd); pair.connectingFd = -1; - fdLock.unlock(); - pair1.fd[1] = pair1.fd[0] + 1; + pair2.fd[1] = pair2.fd[0] + 1; - // pair1.cv->notify_one(); cv.notify_one(); - std::cerr << "+++++ Accept fd " << fd << ": " << pair1.fd[1] << "\n"; + std::cerr << "+++++ Accept fd " << fd << ": " << pair2.fd[1] << "\n"; - return pair1.fd[1]; + return pair2.fd[1]; } int fakeSocketPeer(int fd) @@ -333,6 +435,15 @@ ssize_t fakeSocketAvailableDataLength(int fd) // K: for this fd const int K = (fd&1); + if (!pair.readable[K]) + { + std::cerr << "+++++ EAGAIN: Available data on fd " << fd << "\n"; + errno = EAGAIN; + return -1; + } + + std::cerr << "+++++ Available data on fd " << fd << ": " << pair.buffer[K].size() << "\n"; + return pair.buffer[K].size(); } @@ -354,6 +465,8 @@ ssize_t fakeSocketRead(int fd, void *buf, size_t nbytes) // K: for this fd const int K = (fd&1); + // N: for its peer + const int N = 1 - K; if (pair.fd[K] == -1) { @@ -362,14 +475,14 @@ ssize_t fakeSocketRead(int fd, void *buf, size_t nbytes) return -1; } - if (pair.buffer[K].size() == 0) + if (!pair.readable[K]) { std::cerr << "+++++ EAGAIN: Read from fd " << fd << ", " << nbytes << (nbytes == 1 ? " byte" : " bytes") << "\n"; errno = EAGAIN; return -1; } - // These sockets are record-oriented! + // These sockets are record-oriented! It won't work to read less than the whole buffer. ssize_t result = pair.buffer[K].size(); if (nbytes < result) { @@ -380,8 +493,12 @@ ssize_t fakeSocketRead(int fd, void *buf, size_t nbytes) memmove(buf, pair.buffer[K].data(), result); pair.buffer[K].resize(0); + // If peer is closed, we continue to be readable + if (pair.fd[N] == -1) + pair.readable[K] = true; + else + pair.readable[K] = false; - // pair.cv->notify_one(); cv.notify_one(); std::cerr << "+++++ Read from fd " << fd << ": " << result << (result == 1 ? " byte" : " bytes") << "\n"; @@ -415,7 +532,7 @@ ssize_t fakeSocketFeed(int fd, const void *buf, size_t nbytes) return -1; } - if (pair.buffer[K].size() != 0) + if (pair.readable[K]) { std::cerr << "+++++ EAGAIN: Feed to fd " << fd << ", " << nbytes << (nbytes == 1 ? " byte" : " bytes") << "\n"; errno = EAGAIN; @@ -424,8 +541,8 @@ ssize_t fakeSocketFeed(int fd, const void *buf, size_t nbytes) pair.buffer[K].resize(nbytes); memmove(pair.buffer[K].data(), buf, nbytes); + pair.readable[K] = true; - // pair.cv->notify_one(); cv.notify_one(); std::cerr << "+++++ Feed to fd " << fd << ": " << nbytes << (nbytes == 1 ? " byte" : " bytes") << "\n"; @@ -461,7 +578,7 @@ ssize_t fakeSocketWrite(int fd, const void *buf, size_t nbytes) return -1; } - if (pair.buffer[N].size() != 0) + if (pair.readable[N]) { std::cerr << "+++++ EAGAIN: Write to fd " << fd << ", " << nbytes << (nbytes == 1 ? " byte" : " bytes") << "\n"; errno = EAGAIN; @@ -470,8 +587,8 @@ ssize_t fakeSocketWrite(int fd, const void *buf, size_t nbytes) pair.buffer[N].resize(nbytes); memmove(pair.buffer[N].data(), buf, nbytes); + pair.readable[N] = true; - // pair.cv->notify_one(); cv.notify_one(); std::cerr << "+++++ Write to fd " << fd << ": " << nbytes << (nbytes == 1 ? " byte" : " bytes") << "\n"; @@ -494,9 +611,16 @@ int fakeSocketClose(int fd) std::unique_lock fdLock(pair.mutex[0]); fdsLock.unlock(); - assert(pair.fd[fd&1] == fd); + const int K = (fd&1); + const int N = 1 - K; - pair.fd[fd&1] = -1; + assert(pair.fd[K] == fd); + + pair.fd[K] = -1; + pair.buffer[K].resize(0); + pair.readable[N] = true; + + cv.notify_one(); std::cerr << "+++++ Close fd " << fd << "\n";