@@ -4,19 +4,23 @@ module Database.PostgreSQL.Driver.RawConnection
44 , createRawConnection
55 ) where
66
7- import Control.Monad (void )
7+ import Control.Monad (void , when )
88import Control.Exception (bracketOnError , try )
99import Safe (headMay )
1010import Data.Monoid ((<>) )
11+ import Foreign (castPtr , plusPtr )
1112import System.Socket (socket , AddressInfo (.. ), getAddressInfo , socketAddress ,
1213 aiV4Mapped , AddressInfoException , Socket , connect ,
1314 close , receive , send )
15+ import System.Socket.Unsafe (unsafeReceive )
1416import System.Socket.Family.Inet (Inet )
1517import System.Socket.Type.Stream (Stream , sendAll )
1618import System.Socket.Protocol.TCP (TCP )
1719import System.Socket.Protocol.Default (Default )
1820import System.Socket.Family.Unix (Unix , socketAddressUnixPath )
1921import qualified Data.ByteString as B
22+ import qualified Data.ByteString.Internal as B
23+ import qualified Data.ByteString.Unsafe as B
2024import qualified Data.ByteString.Char8 as BS (pack )
2125
2226import Database.PostgreSQL.Driver.Error
@@ -27,7 +31,8 @@ data RawConnection = RawConnection
2731 { rFlush :: IO ()
2832 , rClose :: IO ()
2933 , rSend :: B. ByteString -> IO ()
30- , rReceive :: Int -> IO B. ByteString
34+ -- ByteString that should be prepended to received ByteString
35+ , rReceive :: B. ByteString -> Int -> IO B. ByteString
3136 }
3237
3338defaultUnixPathDirectory :: B. ByteString
@@ -75,6 +80,17 @@ constructRawConnection s = RawConnection
7580 { rFlush = pure ()
7681 , rClose = close s
7782 , rSend = \ msg -> void $ sendAll s msg mempty
78- , rReceive = \ n -> receive s n mempty
83+ , rReceive = rawReceive s
7984 }
8085
86+ {-# INLINE rawReceive #-}
87+ rawReceive :: Socket f Stream p -> B. ByteString -> Int -> IO B. ByteString
88+ rawReceive s bs n = B. unsafeUseAsCStringLen bs $ \ (prevPtr, prevLen) ->
89+ let bufSize = prevLen + n
90+ in B. createUptoN bufSize $ \ bufPtr -> do
91+ B. memcpy bufPtr (castPtr prevPtr) prevLen
92+ len <- unsafeReceive s (bufPtr `plusPtr` prevLen)
93+ (fromIntegral bufSize) mempty
94+ -- Received empty string means closed connection by the remote host
95+ when (len == 0 ) throwClosedException
96+ pure $ prevLen + fromIntegral len
0 commit comments