diff --git a/core/Network/TLS/Core.hs b/core/Network/TLS/Core.hs index 553f641e2..dcf9f190d 100644 --- a/core/Network/TLS/Core.hs +++ b/core/Network/TLS/Core.hs @@ -82,15 +82,10 @@ sendData ctx dataToSend = liftIO (checkValid ctx) >> mapM_ sendDataChunk (L.toCh recvData :: MonadIO m => Context -> m B.ByteString recvData ctx = liftIO $ do checkValid ctx - E.catchJust safeHandleError_EOF - doRecv - (\() -> return B.empty) - where doRecv = do - pkt <- withReadLock ctx $ recvPacket ctx - either onError process pkt - - safeHandleError_EOF Error_EOF = Just () - safeHandleError_EOF _ = Nothing + pkt <- withReadLock ctx $ recvPacket ctx + either onError process pkt + where onError Error_EOF = -- Not really an error. + return B.empty onError err@(Error_Protocol (reason,fatal,desc)) = terminate err (if fatal then AlertLevel_Fatal else AlertLevel_Warning) desc reason diff --git a/core/Network/TLS/IO.hs b/core/Network/TLS/IO.hs index 72d5c9e0b..cd4b1f45c 100644 --- a/core/Network/TLS/IO.hs +++ b/core/Network/TLS/IO.hs @@ -34,15 +34,18 @@ checkValid ctx = do eofed <- ctxEOF ctx when eofed $ throwIO $ mkIOError eofErrorType "data" Nothing Nothing -readExact :: Context -> Int -> IO Bytes +readExact :: Context -> Int -> IO (Either TLSError Bytes) readExact ctx sz = do hdrbs <- contextRecv ctx sz - when (B.length hdrbs < sz) $ do - setEOF ctx - if B.null hdrbs - then throwCore Error_EOF - else throwCore (Error_Packet ("partial packet: expecting " ++ show sz ++ " bytes, got: " ++ (show $B.length hdrbs))) - return hdrbs + if B.length hdrbs == sz + then return $ Right hdrbs + else do + setEOF ctx + return . Left $ + if B.null hdrbs + then Error_EOF + else Error_Packet ("partial packet: expecting " ++ show sz ++ " bytes, got: " ++ (show $B.length hdrbs)) + -- | recvRecord receive a full TLS record (header + data), from the other side. -- @@ -52,24 +55,32 @@ recvRecord :: Bool -- ^ flag to enable SSLv2 compat ClientHello reception -> IO (Either TLSError (Record Plaintext)) recvRecord compatSSLv2 ctx #ifdef SSLV2_COMPATIBLE - | compatSSLv2 = do - header <- readExact ctx 2 - if B.head header < 0x80 - then readExact ctx 3 >>= either (return . Left) recvLength . decodeHeader . B.append header - else either (return . Left) recvDeprecatedLength $ decodeDeprecatedHeaderLength header + | compatSSLv2 = readExact ctx 2 >>= either (return . Left) sslv2Header #endif - | otherwise = readExact ctx 5 >>= either (return . Left) recvLength . decodeHeader - where recvLength header@(Header _ _ readlen) + | otherwise = readExact ctx 5 >>= either (return . Left) (recvLengthE . decodeHeader) + + where recvLengthE = either (return . Left) recvLength + + recvLength header@(Header _ _ readlen) | readlen > 16384 + 2048 = return $ Left maximumSizeExceeded - | otherwise = readExact ctx (fromIntegral readlen) >>= getRecord header + | otherwise = + readExact ctx (fromIntegral readlen) >>= + either (return . Left) (getRecord header) #ifdef SSLV2_COMPATIBLE + sslv2Header header = + if B.head header >= 0x80 + then either (return . Left) recvDeprecatedLength $ decodeDeprecatedHeaderLength header + else readExact ctx 3 >>= + either (return . Left) (recvLengthE . decodeHeader . B.append header) + recvDeprecatedLength readlen | readlen > 1024 * 4 = return $ Left maximumSizeExceeded | otherwise = do - content <- readExact ctx (fromIntegral readlen) - case decodeDeprecatedHeader readlen content of - Left err -> return $ Left err - Right header -> getRecord header content + res <- readExact ctx (fromIntegral readlen) + case res of + Left e -> return $ Left e + Right content -> + either (return . Left) (flip getRecord content) $ decodeDeprecatedHeader readlen content #endif maximumSizeExceeded = Error_Protocol ("record exceeding maximum size", True, RecordOverflow) getRecord :: Header -> Bytes -> IO (Either TLSError (Record Plaintext))