From ba990b41b1db2076e399e977c9132702b0357add Mon Sep 17 00:00:00 2001 From: qz Date: Sun, 3 Feb 2019 11:58:21 +0300 Subject: [PATCH 1/2] one connection type, DataRow now contains payload in ServerMessage --- src/Database/PostgreSQL/Driver.hs | 3 + src/Database/PostgreSQL/Driver/Connection.hs | 77 ++++++++----------- src/Database/PostgreSQL/Driver/Query.hs | 45 ++++++----- .../PostgreSQL/Protocol/Codecs/Decoders.hs | 4 +- src/Database/PostgreSQL/Protocol/DataRows.hs | 39 +++++----- src/Database/PostgreSQL/Protocol/Decoders.hs | 4 +- src/Database/PostgreSQL/Protocol/Types.hs | 3 +- tests/Codecs/QuickCheck.hs | 2 + tests/Connection.hs | 13 +--- tests/Driver.hs | 36 ++++----- tests/Fault.hs | 14 ++-- tests/Protocol.hs | 8 +- 12 files changed, 117 insertions(+), 131 deletions(-) diff --git a/src/Database/PostgreSQL/Driver.hs b/src/Database/PostgreSQL/Driver.hs index 4a00182..cd3baf7 100644 --- a/src/Database/PostgreSQL/Driver.hs +++ b/src/Database/PostgreSQL/Driver.hs @@ -20,9 +20,12 @@ module Database.PostgreSQL.Driver , sendBatchAndSync , sendBatchAndFlush , readNextData + , readAllData , waitReadyForQuery , sendSimpleQuery , describeStatement + , findFirstError + , findAllErrors -- * Errors , Error(..) , AuthError(..) diff --git a/src/Database/PostgreSQL/Driver/Connection.hs b/src/Database/PostgreSQL/Driver/Connection.hs index 4163cc4..01ae7dd 100644 --- a/src/Database/PostgreSQL/Driver/Connection.hs +++ b/src/Database/PostgreSQL/Driver/Connection.hs @@ -1,8 +1,6 @@ module Database.PostgreSQL.Driver.Connection ( -- * Connection types - AbsConnection(..) - , Connection - , ConnectionCommon + Connection(..) , ServerMessageFilter , NotificationHandler -- * Connection parameters @@ -11,8 +9,7 @@ module Database.PostgreSQL.Driver.Connection , getIntegerDatetimes -- * Work with connection , connect - , connectCommon - , connectCommon' + , connect' , sendStartMessage , sendMessage , sendEncode @@ -53,21 +50,18 @@ import Database.PostgreSQL.Driver.Settings import Database.PostgreSQL.Driver.StatementStorage import Database.PostgreSQL.Driver.RawConnection +type InChan = TQueue (Either ReceiverException ServerMessage) + -- | Public -- Connection parametrized by message type in chan. -data AbsConnection mt = AbsConnection +data Connection = Connection { connRawConnection :: !RawConnection , connReceiverThread :: !(Weak ThreadId) , connStatementStorage :: !StatementStorage , connParameters :: !ConnectionParameters - , connOutChan :: !(TQueue (Either ReceiverException mt)) + , connChan :: !InChan } -type Connection = AbsConnection DataMessage -type ConnectionCommon = AbsConnection ServerMessage - -type InDataChan = TQueue (Either ReceiverException DataMessage) -type InAllChan = TQueue (Either ReceiverException ServerMessage) type ServerMessageFilter = ServerMessage -> Bool type NotificationHandler = Notification -> IO () @@ -87,37 +81,32 @@ data ConnectionParameters = ConnectionParameters -- Getting information about connection -- | Returns a server version of the current connection. -getServerVersion :: AbsConnection c -> ServerVersion +getServerVersion :: Connection -> ServerVersion getServerVersion = paramServerVersion . connParameters -- | Returns a server encoding of the current connection. -getServerEncoding :: AbsConnection c -> B.ByteString +getServerEncoding :: Connection -> B.ByteString getServerEncoding = paramServerEncoding . connParameters -- | Returns whether server uses integer datetimes. -getIntegerDatetimes :: AbsConnection c -> Bool +getIntegerDatetimes :: Connection -> Bool getIntegerDatetimes = paramIntegerDatetimes . connParameters -- | Public -connect :: ConnectionSettings -> IO (Either Error Connection) -connect settings = connectWith settings $ \rawConn params -> - buildConnection rawConn params - (receiverThread rawConn) - -connectCommon +connect :: ConnectionSettings - -> IO (Either Error ConnectionCommon) -connectCommon settings = connectCommon' settings defaultFilter + -> IO (Either Error Connection) +connect settings = connect' settings defaultFilter --- | Like 'connectCommon', but allows specify a message filter. +-- | Like 'connect', but allows specify a message filter. -- Useful for testing. -connectCommon' +connect' :: ConnectionSettings -> ServerMessageFilter - -> IO (Either Error ConnectionCommon) -connectCommon' settings msgFilter = connectWith settings $ \rawConn params -> + -> IO (Either Error Connection) +connect' settings msgFilter = connectWith settings $ \rawConn params -> buildConnection rawConn params - (\chan -> receiverThreadCommon rawConn chan + (\chan -> receiverThread rawConn chan msgFilter defaultNotificationHandler) -- Low-level sending functions @@ -134,13 +123,13 @@ sendMessage rawConn msg = void $ rSend rawConn . runEncode $ encodeClientMessage msg {-# INLINE sendEncode #-} -sendEncode :: AbsConnection c -> Encode -> IO () +sendEncode :: Connection -> Encode -> IO () sendEncode conn = void . rSend (connRawConnection conn) . runEncode connectWith :: ConnectionSettings - -> (RawConnection -> ConnectionParameters -> IO (AbsConnection c)) - -> IO (Either Error (AbsConnection c)) + -> (RawConnection -> ConnectionParameters -> IO Connection) + -> IO (Either Error Connection) connectWith settings buildAction = bracketOnError (createRawConnection settings) @@ -199,8 +188,8 @@ buildConnection :: RawConnection -> ConnectionParameters -- action in receiver thread - -> (TQueue (Either ReceiverException c) -> IO ()) - -> IO (AbsConnection c) + -> (InChan -> IO ()) + -> IO Connection buildConnection rawConn connParams receiverAction = do chan <- newTQueueIO storage <- newStatementStorage @@ -215,12 +204,12 @@ buildConnection rawConn connParams receiverAction = do labelThread tid "postgres-wire receiver" weakTid <- mkWeakThreadId tid - pure AbsConnection + pure Connection { connRawConnection = rawConn , connReceiverThread = weakTid , connStatementStorage = storage , connParameters = connParams - , connOutChan = chan + , connChan = chan } -- | Parses connection parameters. @@ -257,7 +246,7 @@ handshakeTls :: RawConnection -> IO () handshakeTls _ = pure () -- | Closes connection. Does not throw exceptions when socket is closed. -close :: AbsConnection c -> IO () +close :: Connection -> IO () close conn = do maybe (pure ()) killThread =<< deRefWeak (connReceiverThread conn) sendMessage (connRawConnection conn) Terminate `catch` handlerEx @@ -267,25 +256,19 @@ close conn = do | otherwise = throwIO e -- | Any exception prevents thread from future work. -receiverThread :: RawConnection -> InDataChan -> IO () -receiverThread rawConn dataChan = loopExtractDataRows - (\bs -> rReceive rawConn bs 4096) - (writeChan dataChan . Right) - --- | Any exception prevents thread from future work. -receiverThreadCommon +receiverThread :: RawConnection - -> InAllChan + -> InChan -> ServerMessageFilter -> NotificationHandler -> IO () -receiverThreadCommon rawConn chan msgFilter ntfHandler = go "" +receiverThread rawConn chan msgFilter ntfHandler = go "" where go bs = do (rest, msg) <- decodeNextServerMessage bs readMoreAction handler msg >> go rest - readMoreAction = (\bs -> rReceive rawConn bs 4096) + readMoreAction bs = rReceive rawConn bs 4096 handler msg = do dispatchIfNotification msg ntfHandler when (msgFilter msg) $ writeChan chan $ Right msg @@ -316,7 +299,7 @@ defaultFilter msg = case msg of -- messages affecting data handled in dispatcher CommandComplete{} -> False -- messages affecting data handled in dispatcher - DataRow{} -> False + DataRow{} -> True -- messages affecting data handled in dispatcher EmptyQueryResponse -> False -- We need collect all errors to know whether the whole command is successful diff --git a/src/Database/PostgreSQL/Driver/Query.hs b/src/Database/PostgreSQL/Driver/Query.hs index fc1a934..4f26218 100644 --- a/src/Database/PostgreSQL/Driver/Query.hs +++ b/src/Database/PostgreSQL/Driver/Query.hs @@ -5,11 +5,14 @@ module Database.PostgreSQL.Driver.Query , sendBatchAndSync , sendSync , readNextData + , readAllData , waitReadyForQuery -- * Connection common , sendSimpleQuery , describeStatement , collectUntilReadyForQuery + , findFirstError + , findAllErrors ) where import Control.Concurrent.STM.TQueue (TQueue, readTQueue ) @@ -54,31 +57,34 @@ sendSync conn = sendEncode conn $ encodeClientMessage Sync -- | Public {-# INLINABLE readNextData #-} -readNextData :: Connection -> IO (Either Error DataRows) +readNextData :: Connection -> IO (Either Error ByteString) readNextData conn = - readChan (connOutChan conn) >>= + readChan (connChan conn) >>= either (pure . Left . ReceiverError) handleDataMessage where handleDataMessage msg = case msg of - (DataError e) -> pure . Left $ PostgresError e - (DataMessage rows) -> pure . Right $ rows - DataReady -> throwIncorrectUsage - "Expected DataRow message, but got ReadyForQuery" + DataRow bs -> pure . Right $ bs + ReadyForQuery _ -> throwIncorrectUsage "Expected DataRow, but got ReadyForQuery" + _ -> readNextData conn + +{-# INLINABLE readAllData #-} +readAllData :: Connection -> IO (Either Error [ByteString]) +readAllData conn = do + msgs <- collectUntilReadyForQuery conn + return $ msgs >>= (\msgs -> return [bs | DataRow bs <- msgs]) {-# INLINABLE waitReadyForQuery #-} waitReadyForQuery :: Connection -> IO (Either Error ()) waitReadyForQuery conn = - readChan (connOutChan conn) >>= + readChan (connChan conn) >>= either (pure . Left . ReceiverError) handleDataMessage where handleDataMessage msg = case msg of - (DataError e) -> do - -- We should wait for ReadyForQuery anyway. - waitReadyForQuery conn - pure . Left $ PostgresError e - (DataMessage _) -> throwIncorrectUsage - "Expected ReadyForQuery, but got DataRow message" - DataReady -> pure $ Right () + ReadyForQuery _ -> pure $ Right () + DataRow _ -> throwIncorrectUsage "Expected ReadyForQuery, but got DataRow message" + ErrorResponse e -> do + waitReadyForQuery conn + pure . Left $ PostgresError e -- Helper {-# INLINE sendBatchEndBy #-} @@ -115,7 +121,7 @@ constructBatch conn = fmap fold . traverse constructSingle pure $ parseMessage <> bindMessage <> executeMessage -- | Public -sendSimpleQuery :: ConnectionCommon -> ByteString -> IO (Either Error ()) +sendSimpleQuery :: Connection -> ByteString -> IO (Either Error ()) sendSimpleQuery conn q = do sendMessage (connRawConnection conn) $ SimpleQuery (StatementSQL q) (checkErrors =<<) <$> collectUntilReadyForQuery conn @@ -125,7 +131,7 @@ sendSimpleQuery conn q = do -- | Public describeStatement - :: ConnectionCommon + :: Connection -> ByteString -> IO (Either Error (Vector Oid, Vector FieldDescription)) describeStatement conn stmt = do @@ -149,10 +155,10 @@ describeStatement conn stmt = do -- Collects all messages preceding `ReadyForQuery`. collectUntilReadyForQuery - :: ConnectionCommon + :: Connection -> IO (Either Error [ServerMessage]) collectUntilReadyForQuery conn = do - msg <- readChan $ connOutChan conn + msg <- readChan $ connChan conn case msg of Left e -> pure $ Left $ ReceiverError e Right ReadyForQuery{} -> pure $ Right [] @@ -164,6 +170,9 @@ findFirstError [] = Nothing findFirstError (ErrorResponse desc : _) = Just desc findFirstError (_ : xs) = findFirstError xs +findAllErrors :: [ServerMessage] -> [ErrorDesc] +findAllErrors msgs = [e | ErrorResponse e <- msgs] + {-# INLINE readChan #-} readChan :: TQueue a -> IO a readChan = atomically . readTQueue diff --git a/src/Database/PostgreSQL/Protocol/Codecs/Decoders.hs b/src/Database/PostgreSQL/Protocol/Codecs/Decoders.hs index 09088e7..fedfbd7 100644 --- a/src/Database/PostgreSQL/Protocol/Codecs/Decoders.hs +++ b/src/Database/PostgreSQL/Protocol/Codecs/Decoders.hs @@ -40,12 +40,10 @@ import Database.PostgreSQL.Protocol.Codecs.Time import Database.PostgreSQL.Protocol.Codecs.Numeric -- | Decodes DataRow header. --- 1 byte - Message Header --- 4 bytes - Message length -- 2 bytes - count of columns in the DataRow {-# INLINE dataRowHeader #-} dataRowHeader :: Decode () -dataRowHeader = skipBytes 7 +dataRowHeader = skipBytes 2 {-# INLINE fieldLength #-} fieldLength :: Decode Int diff --git a/src/Database/PostgreSQL/Protocol/DataRows.hs b/src/Database/PostgreSQL/Protocol/DataRows.hs index 347cca1..b472272 100644 --- a/src/Database/PostgreSQL/Protocol/DataRows.hs +++ b/src/Database/PostgreSQL/Protocol/DataRows.hs @@ -140,26 +140,25 @@ loopExtractDataRows readMoreAction callback = go "" Empty -- It is better that Decode throws exception on invalid input {-# INLINABLE decodeOneRow #-} -decodeOneRow :: Decode a -> DataRows -> a -decodeOneRow dec Empty = snd $ runDecode dec "" -decodeOneRow dec (DataRows (DataChunk _ bs) _) = snd $ runDecode dec bs +decodeOneRow :: Decode a -> B.ByteString -> a +decodeOneRow dec bs = snd $ runDecode dec bs {-# INLINABLE decodeManyRows #-} -decodeManyRows :: Decode a -> DataRows -> V.Vector a -decodeManyRows dec dr = unsafePerformIO $ do - vec <- MV.unsafeNew . fromIntegral $ countDataRows dr - let go startInd Empty = pure () - go startInd (DataRows (DataChunk len bs) nextDr) = do - let endInd = startInd + fromIntegral len - runDecodeIO - (traverse_ (writeDec vec) [startInd .. (endInd -1)]) - bs - go endInd nextDr - go 0 dr - V.unsafeFreeze vec - where - {-# INLINE writeDec #-} - writeDec vec pos = dec >>= embedIO . MV.unsafeWrite vec pos +decodeManyRows :: Decode a -> [B.ByteString] -> [a] +decodeManyRows dec = map (decodeOneRow dec) + --vec <- MV.unsafeNew . fromIntegral $ countDataRows dr + --let go startInd Empty = pure () + --go startInd (DataRows (DataChunk len bs) nextDr) = do + --let endInd = startInd + fromIntegral len + --runDecodeIO + --(traverse_ (writeDec vec) [startInd .. (endInd -1)]) + --bs + --go endInd nextDr + --go 0 dr + --V.unsafeFreeze vec + --where + --[># INLINE writeDec #<] + --writeDec vec pos = dec >>= embedIO . MV.unsafeWrite vec pos --- -- Utils @@ -183,8 +182,8 @@ reverseDataRows :: DataRows -> DataRows reverseDataRows = foldlDataRows (flip chunk) Empty {-# INLINE countDataRows #-} -countDataRows :: DataRows -> Word -countDataRows = foldlDataRows (\acc (DataChunk c _) -> acc + c) 0 +countDataRows :: [B.ByteString] -> Int +countDataRows = length -- FIXME delete later -- | For testing only diff --git a/src/Database/PostgreSQL/Protocol/Decoders.hs b/src/Database/PostgreSQL/Protocol/Decoders.hs index 6d61877..2e184e2 100644 --- a/src/Database/PostgreSQL/Protocol/Decoders.hs +++ b/src/Database/PostgreSQL/Protocol/Decoders.hs @@ -77,8 +77,8 @@ decodeServerMessage (Header c len) = case chr $ fromIntegral c of >>= eitherToDecode . parseCommandResult) -- Dont parse data rows here. 'D' -> do - _ <- getByteString len - pure DataRow + bs <- getByteString len + pure $ DataRow bs 'I' -> pure EmptyQueryResponse 'E' -> ErrorResponse <$> (getByteString len >>= diff --git a/src/Database/PostgreSQL/Protocol/Types.hs b/src/Database/PostgreSQL/Protocol/Types.hs index 4abd259..b52962f 100644 --- a/src/Database/PostgreSQL/Protocol/Types.hs +++ b/src/Database/PostgreSQL/Protocol/Types.hs @@ -172,8 +172,7 @@ data ServerMessage | BindComplete | CloseComplete | CommandComplete !CommandResult - -- DataRows lays in separate data type - | DataRow + | DataRow !ByteString | EmptyQueryResponse | ErrorResponse !ErrorDesc | NoData diff --git a/tests/Codecs/QuickCheck.hs b/tests/Codecs/QuickCheck.hs index cb017ce..d931111 100644 --- a/tests/Codecs/QuickCheck.hs +++ b/tests/Codecs/QuickCheck.hs @@ -16,6 +16,7 @@ import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as BC import Database.PostgreSQL.Driver +import Database.PostgreSQL.Driver.Query import Database.PostgreSQL.Protocol.DataRows import Database.PostgreSQL.Protocol.Types import Database.PostgreSQL.Protocol.Store.Encode @@ -40,6 +41,7 @@ makeCodecProperty c oid encoder fd v = monadicIO $ do decoder = PD.dataRowHeader *> PD.getNonNullable fd r <- run $ do sendBatchAndSync c [q] + -- msgs <- collectUntilReadyForQuery c dr <- readNextData c waitReadyForQuery c either (error . show) (pure . decodeOneRow decoder) dr diff --git a/tests/Connection.hs b/tests/Connection.hs index 99aad98..eb97247 100644 --- a/tests/Connection.hs +++ b/tests/Connection.hs @@ -9,15 +9,10 @@ import Database.PostgreSQL.Driver.Settings withConnection :: (Connection -> IO a) -> IO a withConnection = bracket (getConnection <$> connect defaultSettings) close --- | Creates a common connection. -withConnectionCommon :: (ConnectionCommon -> IO a) -> IO a -withConnectionCommon = bracket - (getConnection <$> connectCommon defaultSettings) close - -- | Creates connection than collects all server messages in chan. -withConnectionCommonAll :: (ConnectionCommon -> IO a) -> IO a -withConnectionCommonAll = bracket - (getConnection <$> connectCommon' defaultSettings filterAllowedAll) close +withConnectionAll :: (Connection -> IO a) -> IO a +withConnectionAll = bracket + (getConnection <$> connect' defaultSettings filterAllowedAll) close defaultSettings = defaultConnectionSettings { settingsHost = "localhost" @@ -26,7 +21,7 @@ defaultSettings = defaultConnectionSettings , settingsPassword = "" } -getConnection :: Either Error (AbsConnection c)-> AbsConnection c +getConnection :: Either Error Connection -> Connection getConnection (Left e) = error $ "Connection error " ++ show e getConnection (Right c) = c diff --git a/tests/Driver.hs b/tests/Driver.hs index 8afe8ad..dee8081 100644 --- a/tests/Driver.hs +++ b/tests/Driver.hs @@ -58,10 +58,10 @@ fromRight :: Either e a -> a fromRight (Right v) = v fromRight _ = error "fromRight" -fromMessage :: Either e DataRows -> B.ByteString +fromMessage :: Either e B.ByteString -> B.ByteString -- TODO --- 5 bytes -header, 2 bytes -count, 4 bytes - length -fromMessage (Right rows) = B.drop 11 $ flattenDataRows rows +-- 2 bytes -count, 4 bytes - length +fromMessage (Right row) = B.drop 6 $ row fromMessage _ = error "from message" -- | Single batch. @@ -121,9 +121,8 @@ testQueryWithoutResult = assertQueryNoData $ assertQueryNoData :: Query -> IO () assertQueryNoData q = withConnection $ \c -> do sendBatchAndSync c [q] - r <- fromRight <$> readNextData c - waitReadyForQuery c - Empty @=? r + r <- fromRight <$> readAllData c + [] @=? r -- | Asserts that all the received data messages are in form (Right _) checkRightResult :: Connection -> Int -> Assertion @@ -135,9 +134,11 @@ checkRightResult conn n = readNextData conn >>= -- | Asserts that (Left _) as result exists in the received data messages. checkInvalidResult :: Connection -> Int -> Assertion checkInvalidResult conn 0 = assertFailure "Result is right" -checkInvalidResult conn n = readNextData conn >>= - either (const $ pure ()) - (const $ checkInvalidResult conn (n -1)) +checkInvalidResult conn n = do + msgs <- collectUntilReadyForQuery conn + let r = (length . findAllErrors) <$> msgs + either (const $ assertFailure "ReceiverError") (\x -> assertBool "Got errors" (x > 0)) r + -- | Diffirent invalid queries in batches. testInvalidBatch :: IO () @@ -169,7 +170,6 @@ testValidAfterError = withConnection $ \c -> do invalidQuery = Query "SELECT $1" [] Text Text NeverCache sendBatchAndSync c [invalidQuery] checkInvalidResult c 1 - waitReadyForQuery c sendBatchAndSync c [rightQuery] r <- readNextData c @@ -178,7 +178,7 @@ testValidAfterError = withConnection $ \c -> do -- | Describes usual statement. testDescribeStatement :: IO () -testDescribeStatement = withConnectionCommon $ \c -> do +testDescribeStatement = withConnection $ \c -> do r <- describeStatement c $ "select typname, typnamespace, typowner, typlen, typbyval," <> "typcategory, typispreferred, typisdefined, typdelim, typrelid," @@ -188,21 +188,21 @@ testDescribeStatement = withConnectionCommon $ \c -> do -- | Describes statement that returns no data. testDescribeStatementNoData :: IO () -testDescribeStatementNoData = withConnectionCommon $ \c -> do +testDescribeStatementNoData = withConnection $ \c -> do r <- fromRight <$> describeStatement c "SET client_encoding TO UTF8" assertBool "Should be empty" $ null (fst r) assertBool "Should be empty" $ null (snd r) -- | Describes statement that is empty string. testDescribeStatementEmpty :: IO () -testDescribeStatementEmpty = withConnectionCommon $ \c -> do +testDescribeStatementEmpty = withConnection $ \c -> do r <- fromRight <$> describeStatement c "" assertBool "Should be empty" $ null (fst r) assertBool "Should be empty" $ null (snd r) -- | Query using simple query protocol. testSimpleQuery :: IO () -testSimpleQuery = withConnectionCommon $ \c -> do +testSimpleQuery = withConnection $ \c -> do r <- sendSimpleQuery c $ "DROP TABLE IF EXISTS a;" <> "CREATE TABLE a(v int);" @@ -236,8 +236,7 @@ testPreparedStatementCache = withConnection $ \c -> do testLargeQuery :: IO () testLargeQuery = withConnection $ \c -> do sendBatchAndSync c [Query largeStmt [] Text Text NeverCache ] - r <- readNextData c - waitReadyForQuery c + r <- readAllData c assertBool "Should be Right" $ isRight r where largeStmt = "select typname, typnamespace, typowner, typlen, typbyval," @@ -248,16 +247,15 @@ testCorrectDatarows :: IO () testCorrectDatarows = withConnection $ \c -> do let stmt = "SELECT * FROM generate_series(1, 1000)" sendBatchAndSync c [Query stmt [] Text Text NeverCache] - r <- readNextData c + r <- readAllData c case r of Left e -> error $ show e Right rows -> do - map (BS.pack . show ) [1 .. 1000] @=? V.toList (decodeManyRows decodeDataRow rows) + map (BS.pack . show ) [1 .. 1000] @=? (decodeManyRows decodeDataRow rows) countDataRows rows @=? 1000 where -- TODO Right parser later decodeDataRow :: Decode B.ByteString decodeDataRow = do - decodeHeader getInt16BE getByteString . fromIntegral =<< getInt32BE diff --git a/tests/Fault.hs b/tests/Fault.hs index fdf87af..a937b07 100644 --- a/tests/Fault.hs +++ b/tests/Fault.hs @@ -56,8 +56,8 @@ testBatchNextData interruptAction = withConnection $ \c -> do r <- readNextData c assertUnexpected r -testSimpleQuery :: (ConnectionCommon -> IO ()) -> IO () -testSimpleQuery interruptAction = withConnectionCommon $ \c -> do +testSimpleQuery :: (Connection -> IO ()) -> IO () +testSimpleQuery interruptAction = withConnection $ \c -> do asyncVar <- async $ sendSimpleQuery c "SELECT pg_sleep(5)" -- Make sure that query was sent. threadDelay 500000 @@ -73,26 +73,26 @@ testBatchReceiverKilledBefore = withConnection $ \c -> do assertUnexpected r testSimpleQueryReceiverKilledBefore :: IO () -testSimpleQueryReceiverKilledBefore = withConnectionCommon $ \c -> do +testSimpleQueryReceiverKilledBefore = withConnection $ \c -> do killReceiverThread c asyncVar <- async $ sendSimpleQuery c "SELECT pg_sleep(5)" r <- wait asyncVar assertUnexpected r -closeSocket :: AbsConnection c -> IO () +closeSocket :: Connection -> IO () closeSocket = rClose . connRawConnection -throwSocketException :: AbsConnection c -> IO () +throwSocketException :: Connection -> IO () throwSocketException conn = do let exc = SocketException 2 maybe (pure ()) (`throwTo` exc) =<< deRefWeak (connReceiverThread conn) -throwOtherException :: AbsConnection c -> IO () +throwOtherException :: Connection -> IO () throwOtherException conn = do let exc = PatternMatchFail "custom exc" maybe (pure ()) (`throwTo` exc) =<< deRefWeak (connReceiverThread conn) -killReceiverThread :: AbsConnection c -> IO () +killReceiverThread :: Connection -> IO () killReceiverThread conn = maybe (pure ()) killThread =<< deRefWeak (connReceiverThread conn) diff --git a/tests/Protocol.hs b/tests/Protocol.hs index 576f3c9..a9fe266 100644 --- a/tests/Protocol.hs +++ b/tests/Protocol.hs @@ -26,7 +26,7 @@ testProtocolMessages = testGroup "Protocol messages" -- | Tests multi-command simple query. testSimpleQuery :: IO () -testSimpleQuery = withConnectionCommonAll $ \c -> do +testSimpleQuery = withConnectionAll $ \c -> do let rawConn = connRawConnection c statement = StatementSQL $ "DROP TABLE IF EXISTS a;" @@ -44,7 +44,7 @@ testSimpleQuery = withConnectionCommonAll $ \c -> do -- Tests all messages that are permitted in extended query protocol. testExtendedQuery :: IO () -testExtendedQuery = withConnectionCommonAll $ \c -> do +testExtendedQuery = withConnectionAll $ \c -> do let rawConn = connRawConnection c sname = StatementName "statement" pname = PortalName "portal" @@ -88,7 +88,7 @@ testExtendedQuery = withConnectionCommonAll $ \c -> do -- | Tests that PostgreSQL returns `EmptyQueryResponse` when a query -- string is empty. testExtendedEmptyQuery :: IO () -testExtendedEmptyQuery = withConnectionCommonAll $ \c -> do +testExtendedEmptyQuery = withConnectionAll $ \c -> do let rawConn = connRawConnection c sname = StatementName "statement" pname = PortalName "" @@ -108,7 +108,7 @@ testExtendedEmptyQuery = withConnectionCommonAll $ \c -> do -- | Tests that `desribe statement` receives NoData when a statement -- has no data in the result. testExtendedQueryNoData :: IO () -testExtendedQueryNoData = withConnectionCommonAll $ \c -> do +testExtendedQueryNoData = withConnectionAll $ \c -> do let rawConn = connRawConnection c sname = StatementName "statement" statement = StatementSQL "SET client_encoding to UTF8" From d23f36f2a74b8945d7309aad166f8cb389af66b2 Mon Sep 17 00:00:00 2001 From: qz Date: Mon, 4 Feb 2019 20:57:41 +0300 Subject: [PATCH 2/2] strictness fix and some comments --- src/Database/PostgreSQL/Protocol/DataRows.hs | 2 ++ src/Database/PostgreSQL/Protocol/Decoders.hs | 2 +- tests/Codecs/QuickCheck.hs | 1 - 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Database/PostgreSQL/Protocol/DataRows.hs b/src/Database/PostgreSQL/Protocol/DataRows.hs index b472272..2c7937e 100644 --- a/src/Database/PostgreSQL/Protocol/DataRows.hs +++ b/src/Database/PostgreSQL/Protocol/DataRows.hs @@ -183,6 +183,8 @@ reverseDataRows = foldlDataRows (flip chunk) Empty {-# INLINE countDataRows #-} countDataRows :: [B.ByteString] -> Int +-- quote from docs: ... This will be followed by a DataRow message for each row being returned to the frontend. +-- So for each DataRow message we can be sure that ByteString payload contains encoded data for exactly one row. countDataRows = length -- FIXME delete later diff --git a/src/Database/PostgreSQL/Protocol/Decoders.hs b/src/Database/PostgreSQL/Protocol/Decoders.hs index 2e184e2..67debb7 100644 --- a/src/Database/PostgreSQL/Protocol/Decoders.hs +++ b/src/Database/PostgreSQL/Protocol/Decoders.hs @@ -78,7 +78,7 @@ decodeServerMessage (Header c len) = case chr $ fromIntegral c of -- Dont parse data rows here. 'D' -> do bs <- getByteString len - pure $ DataRow bs + pure $! DataRow bs 'I' -> pure EmptyQueryResponse 'E' -> ErrorResponse <$> (getByteString len >>= diff --git a/tests/Codecs/QuickCheck.hs b/tests/Codecs/QuickCheck.hs index d931111..51799b1 100644 --- a/tests/Codecs/QuickCheck.hs +++ b/tests/Codecs/QuickCheck.hs @@ -41,7 +41,6 @@ makeCodecProperty c oid encoder fd v = monadicIO $ do decoder = PD.dataRowHeader *> PD.getNonNullable fd r <- run $ do sendBatchAndSync c [q] - -- msgs <- collectUntilReadyForQuery c dr <- readNextData c waitReadyForQuery c either (error . show) (pure . decodeOneRow decoder) dr