diff --git a/src/Server.php b/src/Server.php index e081fb62..e24bf3a6 100644 --- a/src/Server.php +++ b/src/Server.php @@ -242,6 +242,16 @@ public function handleRequest(ConnectionInterface $conn, RequestInterface $reque $stream = new LengthLimitedStream($stream, $contentLength); } + $upgradeRequest = false; + if ($request->getProtocolVersion() !== '1.0' && $request->hasHeader('Connection') && strtolower($request->getHeaderLine('Connection')) === "upgrade") { + if (!$request->hasHeader('Upgrade') || $request->getHeaderLine('Upgrade') === '') { + // MUST have Upgrade options + $this->emit('error', array(new \InvalidArgumentException('Connection upgrade must specify upgrade protocol.'))); + return $this->writeError($conn, 400, $request); + } + $upgradeRequest = true; + } + $request = $request->withBody(new HttpBodyStream($stream, $contentLength)); if ($request->getProtocolVersion() !== '1.0' && '100-continue' === strtolower($request->getHeaderLine('Expect'))) { @@ -261,7 +271,7 @@ public function handleRequest(ConnectionInterface $conn, RequestInterface $reque $that = $this; $promise->then( - function ($response) use ($that, $conn, $request) { + function ($response) use ($that, $conn, $request, $contentLength, $stream, $upgradeRequest) { if (!$response instanceof ResponseInterface) { $message = 'The response callback is expected to resolve with an object implementing Psr\Http\Message\ResponseInterface, but resolved with "%s" instead.'; $message = sprintf($message, is_object($response) ? get_class($response) : gettype($response)); @@ -270,6 +280,71 @@ function ($response) use ($that, $conn, $request) { $that->emit('error', array($exception)); return $that->writeError($conn, 500, $request); } + + if ($response->getStatusCode() === 426) { + if (!$response->hasHeader('Upgrade') || $response->getHeaderLine('Upgrade') === '') { + $message = 'HTTP 1.1 426 response requires `Upgrade` header.'; + $exception = new \RuntimeException($message); + + $that->emit('error', array($exception)); + return $that->writeError($conn, 500, $request); + } + } + + $upgradeConnection = false; + if ($response->getStatusCode() === 101) { + if (!$upgradeRequest) { + $message = 'HTTP status 101 is not valid when no upgrade was requested'; + $exception = new \RuntimeException($message); + + $that->emit('error', array($exception)); + return $that->writeError($conn, 500, $request); + } + + if ($response->getProtocolVersion() === '1.0') { + $message = 'HTTP status 101 is not valid with protocol version 1.0'; + $exception = new \RuntimeException($message); + + $that->emit('error', array($exception)); + return $that->writeError($conn, 500, $request); + } + + if (!$response->hasHeader('Connection') || strtolower($response->getHeaderLine('Connection')) !== 'upgrade') { + $message = 'HTTP 1.1 Upgrade requires `Connection: upgrade` header.'; + $exception = new \RuntimeException($message); + + $that->emit('error', array($exception)); + return $that->writeError($conn, 500, $request); + } + + if (!$response->hasHeader('Upgrade') || $response->getHeaderLine('Upgrade') === '') { + $message = 'HTTP 1.1 Upgrade requires `Upgrade` header with exactly one protocol specified.'; + $exception = new \RuntimeException($message); + + $that->emit('error', array($exception)); + return $that->writeError($conn, 500, $request); + } + + $requestedProtocols = explode(',', preg_replace('/\s+/', '', $request->getHeaderLine('Upgrade'))); + + if (!in_array(trim($response->getHeaderLine('Upgrade')), $requestedProtocols)) { + $message = 'Upgrade requires response protocol to be one of the `Upgrade` protocols specified by the request.'; + $exception = new \RuntimeException($message); + + $that->emit('error', array($exception)); + return $that->writeError($conn, 500, $request); + } + + $upgradeConnection = true; + } + + if (!$upgradeConnection && $contentLength === 0) { + // If Body is empty or Content-Length is 0 and won't emit further data, + // 'data' events from other streams won't be called anymore + $stream->emit('end'); + $stream->close(); + } + $that->handleResponse($conn, $request, $response); }, function ($error) use ($that, $conn, $request) { @@ -281,13 +356,6 @@ function ($error) use ($that, $conn, $request) { return $that->writeError($conn, 500, $request); } ); - - if ($contentLength === 0) { - // If Body is empty or Content-Length is 0 and won't emit further data, - // 'data' events from other streams won't be called anymore - $stream->emit('end'); - $stream->close(); - } } /** @internal */ @@ -349,7 +417,7 @@ public function handleResponse(ConnectionInterface $connection, RequestInterface // HTTP/1.1 assumes persistent connection support by default // we do not support persistent connections, so let the client know - if ($request->getProtocolVersion() === '1.1') { + if ($request->getProtocolVersion() === '1.1' && $response->getStatusCode() !== 101) { $response = $response->withHeader('Connection', 'close'); } @@ -359,9 +427,12 @@ public function handleResponse(ConnectionInterface $connection, RequestInterface $response = $response->withoutHeader('Content-Length')->withoutHeader('Transfer-Encoding'); } - // response to HEAD and 1xx, 204 and 304 responses MUST NOT include a body - if ($request->getMethod() === 'HEAD' || ($code >= 100 && $code < 200) || $code === 204 || $code === 304) { - $response = $response->withBody(Psr7Implementation\stream_for('')); + // 101 response (Upgrade) should hold onto the body + if ($code !== 101) { + // response to HEAD and 1xx, 204 and 304 responses MUST NOT include a body + if ($request->getMethod() === 'HEAD' || ($code >= 100 && $code < 200) || $code === 204 || $code === 304) { + $response = $response->withBody(Psr7Implementation\stream_for('')); + } } $this->handleResponseBody($response, $connection); diff --git a/tests/ServerTest.php b/tests/ServerTest.php index b7ddac2f..22a35457 100644 --- a/tests/ServerTest.php +++ b/tests/ServerTest.php @@ -2229,6 +2229,252 @@ function ($data) use (&$buffer) { $this->assertInstanceOf('RuntimeException', $exception); } + private function getUpgradeHeader() + { + $data = "GET / HTTP/1.1\r\n"; + $data .= "Host: localhost\r\n"; + $data .= "Connection: Upgrade\r\n"; + $data .= "Upgrade: echo\r\n\r\n"; + + return $data; + } + + public function testConnectionUpgradeEcho() + { + $that = $this; + $server = new Server($this->socket, function (RequestInterface $request) use ($that) { + $responseStream = new ReadableStream(); + $request->getBody()->on('data', function ($data) use ($responseStream) { + $responseStream->emit('data', array($data)); + }); + + $that->assertEquals('Upgrade', $request->getHeaderLine('Connection')); + $that->assertEquals('echo', $request->getHeaderLine('Upgrade')); + + $response = new Response( + 101, + array( + 'Connection' => 'Upgrade', + 'Upgrade' => 'echo' + ), + $responseStream); + return $response; + }); + + $buffer = ''; + $this->connection + ->expects($this->any()) + ->method('write') + ->will( + $this->returnCallback( + function ($data) use (&$buffer) { + $buffer .= $data; + } + ) + ); + + $this->socket->emit('connection', array($this->connection)); + + $this->connection->emit('data', array($this->getUpgradeHeader())); + + $this->connection->emit('data', array('text to be echoed')); + + $this->assertStringStartsWith("HTTP/1.1 101 Switching Protocols\r\n", $buffer); + $this->assertContains("\r\nConnection: Upgrade\r\n", $buffer); + $this->assertContains("\r\nUpgrade: echo\r\n", $buffer); + $this->assertStringEndsWith("\r\n\r\ntext to be echoed", $buffer); + } + + public function testUpgradeWithNoProtocolRespondsWithError() + { + $that = $this; + $server = new Server($this->socket, function (RequestInterface $request) use ($that) { + $that->fail('Callback should not be called'); + }); + + $exception = null; + $server->on('error', function (\Exception $ex) use (&$exception) { + $exception = $ex; + }); + + $buffer = ''; + $this->connection + ->expects($this->any()) + ->method('write') + ->will( + $this->returnCallback( + function ($data) use (&$buffer) { + $buffer .= $data; + } + ) + ); + + $this->socket->emit('connection', array($this->connection)); + + $data = "GET / HTTP/1.1\r\n"; + $data .= "Host: localhost\r\n"; + $data .= "Connection: Upgrade\r\n\r\n"; + + $this->connection->emit('data', array($this->getUpgradeHeader())); + + $this->assertStringStartsWith("HTTP/1.1 500 Internal Server Error\r\n", $buffer); + $this->assertInstanceOf('RuntimeException', $exception); + } + + public function testUpgrade101MustContainUpgradeHeaderWithNewProtocol() + { + $that = $this; + $server = new Server($this->socket, function (RequestInterface $request) use ($that) { + $responseStream = new ReadableStream(); + $that->assertEquals('Upgrade', $request->getHeaderLine('Connection')); + $that->assertEquals('echo', $request->getHeaderLine('Upgrade')); + + $response = new Response( + 101, + array( + 'Connection' => 'Upgrade' + ), + $responseStream); + return $response; + }); + + $exception = null; + $server->on('error', function (\Exception $ex) use (&$exception) { + $exception = $ex; + }); + + $buffer = ''; + $this->connection + ->expects($this->any()) + ->method('write') + ->will( + $this->returnCallback( + function ($data) use (&$buffer) { + $buffer .= $data; + } + ) + ); + + $this->socket->emit('connection', array($this->connection)); + + $this->connection->emit('data', array($this->getUpgradeHeader())); + + $this->assertStringStartsWith("HTTP/1.1 500 Internal Server Error\r\n", $buffer); + $this->assertInstanceOf('RuntimeException', $exception); + } + + public function testUpgradeProtocolMustBeOneRequested() + { + $that = $this; + $server = new Server($this->socket, function (RequestInterface $request) use ($that) { + $responseStream = new ReadableStream(); + $that->assertEquals('Upgrade', $request->getHeaderLine('Connection')); + $that->assertEquals('echo', $request->getHeaderLine('Upgrade')); + + $response = new Response( + 101, + array( + 'Connection' => 'Upgrade', + 'Upgrade' => 'notecho' + ), + $responseStream); + return $response; + }); + + $exception = null; + $server->on('error', function (\Exception $ex) use (&$exception) { + $exception = $ex; + }); + + $buffer = ''; + $this->connection + ->expects($this->any()) + ->method('write') + ->will( + $this->returnCallback( + function ($data) use (&$buffer) { + $buffer .= $data; + } + ) + ); + + $this->socket->emit('connection', array($this->connection)); + + $this->connection->emit('data', array($this->getUpgradeHeader())); + + $this->assertStringStartsWith("HTTP/1.1 500 Internal Server Error\r\n", $buffer); + $this->assertInstanceOf('RuntimeException', $exception); + } + + public function testUpgrade426WithUpgradeHeader() + { + $server = new Server($this->socket, function (RequestInterface $request) { + $response = new Response( + 426, + array( + 'Upgrade' => 'something' + )); + return $response; + }); + + $exception = null; + $server->on('error', function (\Exception $ex) use (&$exception) { + $exception = $ex; + }); + + $buffer = ''; + $this->connection + ->expects($this->any()) + ->method('write') + ->will( + $this->returnCallback( + function ($data) use (&$buffer) { + $buffer .= $data; + } + ) + ); + + $this->socket->emit('connection', array($this->connection)); + + $this->connection->emit('data', array($this->getUpgradeHeader())); + + $this->assertStringStartsWith("HTTP/1.1 426 Upgrade Required\r\n", $buffer); + } + + public function testUpgrade426MustContainUpgradeHeaderWithProtocol() + { + $server = new Server($this->socket, function (RequestInterface $request) { + $response = new Response( + 426, + array()); + return $response; + }); + + $exception = null; + $server->on('error', function (\Exception $ex) use (&$exception) { + $exception = $ex; + }); + + $buffer = ''; + $this->connection + ->expects($this->any()) + ->method('write') + ->will( + $this->returnCallback( + function ($data) use (&$buffer) { + $buffer .= $data; + } + ) + ); + + $this->socket->emit('connection', array($this->connection)); + + $this->connection->emit('data', array($this->getUpgradeHeader())); + + $this->assertStringStartsWith("HTTP/1.1 500 Internal Server Error\r\n", $buffer); + $this->assertInstanceOf('RuntimeException', $exception); + } + private function createGetRequest() { $data = "GET / HTTP/1.1\r\n";