Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Database/PostgreSQL/Driver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ module Database.PostgreSQL.Driver
, sendBatchAndSync
, sendBatchAndFlush
, readNextData
, readAllData
, waitReadyForQuery
, sendSimpleQuery
, describeStatement
, findFirstError
, findAllErrors
-- * Errors
, Error(..)
, AuthError(..)
Expand Down
77 changes: 30 additions & 47 deletions src/Database/PostgreSQL/Driver/Connection.hs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
module Database.PostgreSQL.Driver.Connection
( -- * Connection types
AbsConnection(..)
, Connection
, ConnectionCommon
Connection(..)
, ServerMessageFilter
, NotificationHandler
-- * Connection parameters
Expand All @@ -11,8 +9,7 @@ module Database.PostgreSQL.Driver.Connection
, getIntegerDatetimes
-- * Work with connection
, connect
, connectCommon
, connectCommon'
, connect'
, sendStartMessage
, sendMessage
, sendEncode
Expand Down Expand Up @@ -53,21 +50,18 @@ import Database.PostgreSQL.Driver.Settings
import Database.PostgreSQL.Driver.StatementStorage
import Database.PostgreSQL.Driver.RawConnection

type InChan = TQueue (Either ReceiverException ServerMessage)

-- | Public
-- Connection parametrized by message type in chan.
data AbsConnection mt = AbsConnection
data Connection = Connection
{ connRawConnection :: !RawConnection
, connReceiverThread :: !(Weak ThreadId)
, connStatementStorage :: !StatementStorage
, connParameters :: !ConnectionParameters
, connOutChan :: !(TQueue (Either ReceiverException mt))
, connChan :: !InChan
}

type Connection = AbsConnection DataMessage
type ConnectionCommon = AbsConnection ServerMessage

type InDataChan = TQueue (Either ReceiverException DataMessage)
type InAllChan = TQueue (Either ReceiverException ServerMessage)

type ServerMessageFilter = ServerMessage -> Bool
type NotificationHandler = Notification -> IO ()
Expand All @@ -87,37 +81,32 @@ data ConnectionParameters = ConnectionParameters
-- Getting information about connection

-- | Returns a server version of the current connection.
getServerVersion :: AbsConnection c -> ServerVersion
getServerVersion :: Connection -> ServerVersion
getServerVersion = paramServerVersion . connParameters

-- | Returns a server encoding of the current connection.
getServerEncoding :: AbsConnection c -> B.ByteString
getServerEncoding :: Connection -> B.ByteString
getServerEncoding = paramServerEncoding . connParameters

-- | Returns whether server uses integer datetimes.
getIntegerDatetimes :: AbsConnection c -> Bool
getIntegerDatetimes :: Connection -> Bool
getIntegerDatetimes = paramIntegerDatetimes . connParameters

-- | Public
connect :: ConnectionSettings -> IO (Either Error Connection)
connect settings = connectWith settings $ \rawConn params ->
buildConnection rawConn params
(receiverThread rawConn)

connectCommon
connect
:: ConnectionSettings
-> IO (Either Error ConnectionCommon)
connectCommon settings = connectCommon' settings defaultFilter
-> IO (Either Error Connection)
connect settings = connect' settings defaultFilter

-- | Like 'connectCommon', but allows specify a message filter.
-- | Like 'connect', but allows specify a message filter.
-- Useful for testing.
connectCommon'
connect'
:: ConnectionSettings
-> ServerMessageFilter
-> IO (Either Error ConnectionCommon)
connectCommon' settings msgFilter = connectWith settings $ \rawConn params ->
-> IO (Either Error Connection)
connect' settings msgFilter = connectWith settings $ \rawConn params ->
buildConnection rawConn params
(\chan -> receiverThreadCommon rawConn chan
(\chan -> receiverThread rawConn chan
msgFilter defaultNotificationHandler)

-- Low-level sending functions
Expand All @@ -134,13 +123,13 @@ sendMessage rawConn msg = void $
rSend rawConn . runEncode $ encodeClientMessage msg

{-# INLINE sendEncode #-}
sendEncode :: AbsConnection c -> Encode -> IO ()
sendEncode :: Connection -> Encode -> IO ()
sendEncode conn = void . rSend (connRawConnection conn) . runEncode

connectWith
:: ConnectionSettings
-> (RawConnection -> ConnectionParameters -> IO (AbsConnection c))
-> IO (Either Error (AbsConnection c))
-> (RawConnection -> ConnectionParameters -> IO Connection)
-> IO (Either Error Connection)
connectWith settings buildAction =
bracketOnError
(createRawConnection settings)
Expand Down Expand Up @@ -199,8 +188,8 @@ buildConnection
:: RawConnection
-> ConnectionParameters
-- action in receiver thread
-> (TQueue (Either ReceiverException c) -> IO ())
-> IO (AbsConnection c)
-> (InChan -> IO ())
-> IO Connection
buildConnection rawConn connParams receiverAction = do
chan <- newTQueueIO
storage <- newStatementStorage
Expand All @@ -215,12 +204,12 @@ buildConnection rawConn connParams receiverAction = do
labelThread tid "postgres-wire receiver"
weakTid <- mkWeakThreadId tid

pure AbsConnection
pure Connection
{ connRawConnection = rawConn
, connReceiverThread = weakTid
, connStatementStorage = storage
, connParameters = connParams
, connOutChan = chan
, connChan = chan
}

-- | Parses connection parameters.
Expand Down Expand Up @@ -257,7 +246,7 @@ handshakeTls :: RawConnection -> IO ()
handshakeTls _ = pure ()

-- | Closes connection. Does not throw exceptions when socket is closed.
close :: AbsConnection c -> IO ()
close :: Connection -> IO ()
close conn = do
maybe (pure ()) killThread =<< deRefWeak (connReceiverThread conn)
sendMessage (connRawConnection conn) Terminate `catch` handlerEx
Expand All @@ -267,25 +256,19 @@ close conn = do
| otherwise = throwIO e

-- | Any exception prevents thread from future work.
receiverThread :: RawConnection -> InDataChan -> IO ()
receiverThread rawConn dataChan = loopExtractDataRows
(\bs -> rReceive rawConn bs 4096)
(writeChan dataChan . Right)

-- | Any exception prevents thread from future work.
receiverThreadCommon
receiverThread
:: RawConnection
-> InAllChan
-> InChan
-> ServerMessageFilter
-> NotificationHandler
-> IO ()
receiverThreadCommon rawConn chan msgFilter ntfHandler = go ""
receiverThread rawConn chan msgFilter ntfHandler = go ""
where
go bs = do
(rest, msg) <- decodeNextServerMessage bs readMoreAction
handler msg >> go rest

readMoreAction = (\bs -> rReceive rawConn bs 4096)
readMoreAction bs = rReceive rawConn bs 4096
handler msg = do
dispatchIfNotification msg ntfHandler
when (msgFilter msg) $ writeChan chan $ Right msg
Expand Down Expand Up @@ -316,7 +299,7 @@ defaultFilter msg = case msg of
-- messages affecting data handled in dispatcher
CommandComplete{} -> False
-- messages affecting data handled in dispatcher
DataRow{} -> False
DataRow{} -> True
-- messages affecting data handled in dispatcher
EmptyQueryResponse -> False
-- We need collect all errors to know whether the whole command is successful
Expand Down
45 changes: 27 additions & 18 deletions src/Database/PostgreSQL/Driver/Query.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ module Database.PostgreSQL.Driver.Query
, sendBatchAndSync
, sendSync
, readNextData
, readAllData
, waitReadyForQuery
-- * Connection common
, sendSimpleQuery
, describeStatement
, collectUntilReadyForQuery
, findFirstError
, findAllErrors
) where

import Control.Concurrent.STM.TQueue (TQueue, readTQueue )
Expand Down Expand Up @@ -54,31 +57,34 @@ sendSync conn = sendEncode conn $ encodeClientMessage Sync

-- | Public
{-# INLINABLE readNextData #-}
readNextData :: Connection -> IO (Either Error DataRows)
readNextData :: Connection -> IO (Either Error ByteString)
readNextData conn =
readChan (connOutChan conn) >>=
readChan (connChan conn) >>=
either (pure . Left . ReceiverError) handleDataMessage
where
handleDataMessage msg = case msg of
(DataError e) -> pure . Left $ PostgresError e
(DataMessage rows) -> pure . Right $ rows
DataReady -> throwIncorrectUsage
"Expected DataRow message, but got ReadyForQuery"
DataRow bs -> pure . Right $ bs
ReadyForQuery _ -> throwIncorrectUsage "Expected DataRow, but got ReadyForQuery"
_ -> readNextData conn

{-# INLINABLE readAllData #-}
readAllData :: Connection -> IO (Either Error [ByteString])
readAllData conn = do
msgs <- collectUntilReadyForQuery conn
return $ msgs >>= (\msgs -> return [bs | DataRow bs <- msgs])

{-# INLINABLE waitReadyForQuery #-}
waitReadyForQuery :: Connection -> IO (Either Error ())
waitReadyForQuery conn =
readChan (connOutChan conn) >>=
readChan (connChan conn) >>=
either (pure . Left . ReceiverError) handleDataMessage
where
handleDataMessage msg = case msg of
(DataError e) -> do
-- We should wait for ReadyForQuery anyway.
waitReadyForQuery conn
pure . Left $ PostgresError e
(DataMessage _) -> throwIncorrectUsage
"Expected ReadyForQuery, but got DataRow message"
DataReady -> pure $ Right ()
ReadyForQuery _ -> pure $ Right ()
DataRow _ -> throwIncorrectUsage "Expected ReadyForQuery, but got DataRow message"
ErrorResponse e -> do
waitReadyForQuery conn
pure . Left $ PostgresError e

-- Helper
{-# INLINE sendBatchEndBy #-}
Expand Down Expand Up @@ -115,7 +121,7 @@ constructBatch conn = fmap fold . traverse constructSingle
pure $ parseMessage <> bindMessage <> executeMessage

-- | Public
sendSimpleQuery :: ConnectionCommon -> ByteString -> IO (Either Error ())
sendSimpleQuery :: Connection -> ByteString -> IO (Either Error ())
sendSimpleQuery conn q = do
sendMessage (connRawConnection conn) $ SimpleQuery (StatementSQL q)
(checkErrors =<<) <$> collectUntilReadyForQuery conn
Expand All @@ -125,7 +131,7 @@ sendSimpleQuery conn q = do

-- | Public
describeStatement
:: ConnectionCommon
:: Connection
-> ByteString
-> IO (Either Error (Vector Oid, Vector FieldDescription))
describeStatement conn stmt = do
Expand All @@ -149,10 +155,10 @@ describeStatement conn stmt = do

-- Collects all messages preceding `ReadyForQuery`.
collectUntilReadyForQuery
:: ConnectionCommon
:: Connection
-> IO (Either Error [ServerMessage])
collectUntilReadyForQuery conn = do
msg <- readChan $ connOutChan conn
msg <- readChan $ connChan conn
case msg of
Left e -> pure $ Left $ ReceiverError e
Right ReadyForQuery{} -> pure $ Right []
Expand All @@ -164,6 +170,9 @@ findFirstError [] = Nothing
findFirstError (ErrorResponse desc : _) = Just desc
findFirstError (_ : xs) = findFirstError xs

findAllErrors :: [ServerMessage] -> [ErrorDesc]
findAllErrors msgs = [e | ErrorResponse e <- msgs]

{-# INLINE readChan #-}
readChan :: TQueue a -> IO a
readChan = atomically . readTQueue
4 changes: 1 addition & 3 deletions src/Database/PostgreSQL/Protocol/Codecs/Decoders.hs
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,10 @@ import Database.PostgreSQL.Protocol.Codecs.Time
import Database.PostgreSQL.Protocol.Codecs.Numeric

-- | Decodes DataRow header.
-- 1 byte - Message Header
-- 4 bytes - Message length
-- 2 bytes - count of columns in the DataRow
{-# INLINE dataRowHeader #-}
dataRowHeader :: Decode ()
dataRowHeader = skipBytes 7
dataRowHeader = skipBytes 2

{-# INLINE fieldLength #-}
fieldLength :: Decode Int
Expand Down
41 changes: 21 additions & 20 deletions src/Database/PostgreSQL/Protocol/DataRows.hs
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,25 @@ loopExtractDataRows readMoreAction callback = go "" Empty

-- It is better that Decode throws exception on invalid input
{-# INLINABLE decodeOneRow #-}
decodeOneRow :: Decode a -> DataRows -> a
decodeOneRow dec Empty = snd $ runDecode dec ""
decodeOneRow dec (DataRows (DataChunk _ bs) _) = snd $ runDecode dec bs
decodeOneRow :: Decode a -> B.ByteString -> a
decodeOneRow dec bs = snd $ runDecode dec bs

{-# INLINABLE decodeManyRows #-}
decodeManyRows :: Decode a -> DataRows -> V.Vector a
decodeManyRows dec dr = unsafePerformIO $ do
vec <- MV.unsafeNew . fromIntegral $ countDataRows dr
let go startInd Empty = pure ()
go startInd (DataRows (DataChunk len bs) nextDr) = do
let endInd = startInd + fromIntegral len
runDecodeIO
(traverse_ (writeDec vec) [startInd .. (endInd -1)])
bs
go endInd nextDr
go 0 dr
V.unsafeFreeze vec
where
{-# INLINE writeDec #-}
writeDec vec pos = dec >>= embedIO . MV.unsafeWrite vec pos
decodeManyRows :: Decode a -> [B.ByteString] -> [a]
decodeManyRows dec = map (decodeOneRow dec)
--vec <- MV.unsafeNew . fromIntegral $ countDataRows dr
--let go startInd Empty = pure ()
--go startInd (DataRows (DataChunk len bs) nextDr) = do
--let endInd = startInd + fromIntegral len
--runDecodeIO
--(traverse_ (writeDec vec) [startInd .. (endInd -1)])
--bs
--go endInd nextDr
--go 0 dr
--V.unsafeFreeze vec
--where
--[># INLINE writeDec #<]
--writeDec vec pos = dec >>= embedIO . MV.unsafeWrite vec pos

---
-- Utils
Expand All @@ -183,8 +182,10 @@ reverseDataRows :: DataRows -> DataRows
reverseDataRows = foldlDataRows (flip chunk) Empty

{-# INLINE countDataRows #-}
countDataRows :: DataRows -> Word
countDataRows = foldlDataRows (\acc (DataChunk c _) -> acc + c) 0
countDataRows :: [B.ByteString] -> Int
-- quote from docs: ... This will be followed by a DataRow message for each row being returned to the frontend.
-- So for each DataRow message we can be sure that ByteString payload contains encoded data for exactly one row.
countDataRows = length
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the logic changed there, are we guaranteed that there is one row per list element?


-- FIXME delete later
-- | For testing only
Expand Down
4 changes: 2 additions & 2 deletions src/Database/PostgreSQL/Protocol/Decoders.hs
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ decodeServerMessage (Header c len) = case chr $ fromIntegral c of
>>= eitherToDecode . parseCommandResult)
-- Dont parse data rows here.
'D' -> do
_ <- getByteString len
pure DataRow
bs <- getByteString len
pure $! DataRow bs
'I' -> pure EmptyQueryResponse
'E' -> ErrorResponse <$>
(getByteString len >>=
Expand Down
3 changes: 1 addition & 2 deletions src/Database/PostgreSQL/Protocol/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ data ServerMessage
| BindComplete
| CloseComplete
| CommandComplete !CommandResult
-- DataRows lays in separate data type
| DataRow
| DataRow !ByteString
| EmptyQueryResponse
| ErrorResponse !ErrorDesc
| NoData
Expand Down
Loading