4545from enum import Enum
4646from distutils .version import LooseVersion
4747
48- # Try to use psycopg2 by default. If psycopg2 isn't available then use
49- # pg8000 which is slower but much more portable because uses only
50- # pure-Python code
5148try :
52- import psycopg2 as pglib
49+ import asyncpg as pglib
5350except ImportError :
54- try :
55- import pg8000 as pglib
56- except ImportError :
57- raise ImportError ("You must have psycopg2 or pg8000 modules installed" )
51+ raise ImportError ("You must have asyncpg module installed" )
5852
5953# ports used by nodes
6054bound_ports = set ()
@@ -193,26 +187,34 @@ def __init__(self,
193187 password = None ):
194188
195189 # Use default user if not specified
196- username = username or default_username ()
197-
190+ self .username = username or default_username ()
191+ self .dbname = dbname
192+ self .host = host
193+ self .password = password
198194 self .parent_node = parent_node
195+ self .connection = None
196+ self .current_transaction = None
199197
200- self .connection = pglib .connect (
201- database = dbname ,
202- user = username ,
203- port = parent_node .port ,
204- host = host ,
205- password = password )
198+ async def init_connection (self ):
199+ if self .connection :
200+ return
206201
207- self .cursor = self .connection .cursor ()
202+ self .connection = await pglib .connect (
203+ database = self .dbname ,
204+ user = self .username ,
205+ port = self .parent_node .port ,
206+ host = self .host ,
207+ password = self .password )
208208
209- def __enter__ (self ):
209+ async def __aenter__ (self ):
210210 return self
211211
212- def __exit__ (self , type , value , traceback ):
213- self .close ()
212+ async def __aexit__ (self , type , value , traceback ):
213+ await self .close ()
214+
215+ async def begin (self , isolation_level = IsolationLevel .ReadCommitted ):
216+ await self .init_connection ()
214217
215- def begin (self , isolation_level = IsolationLevel .ReadCommitted ):
216218 # yapf: disable
217219 levels = [
218220 'read uncommitted' ,
@@ -245,37 +247,45 @@ def begin(self, isolation_level=IsolationLevel.ReadCommitted):
245247
246248 # Set isolation level
247249 cmd = 'SET TRANSACTION ISOLATION LEVEL {}'
248- self .cursor .execute (cmd .format (isolation_level ))
250+ self .current_transaction = self .connection .transaction ()
251+ await self .current_transaction .start ()
252+ await self .connection .execute (cmd .format (isolation_level ))
249253
250254 return self
251255
252- def commit (self ):
253- self .connection .commit ()
256+ async def commit (self ):
257+ if not self .current_transaction :
258+ raise QueryException ("transaction is not started" )
254259
255- return self
256-
257- def rollback (self ):
258- self .connection .rollback ()
260+ await self .current_transaction .commit ()
261+ self .current_transaction = None
259262
260- return self
263+ async def rollback (self ):
264+ if not self .current_transaction :
265+ raise QueryException ("transaction is not started" )
261266
262- def execute ( self , query , * args ):
263- self .cursor . execute ( query , args )
267+ await self . current_transaction . rollback ()
268+ self .current_transaction = None
264269
265- try :
266- res = self .cursor .fetchall ()
267-
268- # pg8000 might return tuples
269- if isinstance (res , tuple ):
270- res = [tuple (t ) for t in res ]
270+ async def execute (self , query , * args ):
271+ await self .init_connection ()
272+ if self .current_transaction :
273+ return await self .connection .execute (query , * args )
274+ else :
275+ async with self .connection .transaction ():
276+ return await self .connection .execute (query , * args )
271277
272- return res
273- except Exception :
274- return None
278+ async def fetch (self , query , * args ):
279+ await self .init_connection ()
280+ if self .current_transaction :
281+ return await self .connection .fetch (query , * args )
282+ else :
283+ async with self .connection .transaction ():
284+ return await self .connection .fetch (query , * args )
275285
276- def close (self ):
277- self .cursor . close ()
278- self .connection .close ()
286+ async def close (self ):
287+ if self .connection :
288+ await self .connection .close ()
279289
280290
281291class NodeBackup (object ):
@@ -943,7 +953,7 @@ def restore(self, dbname, filename, username=None):
943953
944954 self .psql (dbname = dbname , filename = filename , username = username )
945955
946- def poll_query_until (self ,
956+ async def poll_query_until (self ,
947957 dbname ,
948958 query ,
949959 username = None ,
@@ -973,41 +983,54 @@ def poll_query_until(self,
973983
974984 attempts = 0
975985 while max_attempts == 0 or attempts < max_attempts :
976- try :
977- res = self .execute (dbname = dbname ,
978- query = query ,
979- username = username ,
980- commit = True )
981-
982- if expected is None and res is None :
983- return # done
986+ res = await self .fetch (dbname = dbname ,
987+ query = query ,
988+ username = username ,
989+ commit = True )
984990
985- if res is None :
986- raise QueryException ( 'Query returned None' )
991+ if expected is None and res is None :
992+ return # done
987993
988- if len ( res ) == 0 :
989- raise QueryException ('Query returned 0 rows ' )
994+ if res is None :
995+ raise QueryException ('Query returned None ' )
990996
991- if len (res [ 0 ] ) == 0 :
992- raise QueryException ('Query returned 0 columns ' )
997+ if len (res ) == 0 :
998+ raise QueryException ('Query returned 0 rows ' )
993999
994- if res [0 ][ 0 ] :
995- return # done
1000+ if len ( res [0 ]) == 0 :
1001+ raise QueryException ( 'Query returned 0 columns' )
9961002
997- except pglib .ProgrammingError as e :
998- if raise_programming_error :
999- raise e
1000-
1001- except pglib .InternalError as e :
1002- if raise_internal_error :
1003- raise e
1003+ if res [0 ][0 ]:
1004+ return # done
10041005
10051006 time .sleep (sleep_time )
10061007 attempts += 1
10071008
10081009 raise TimeoutException ('Query timeout' )
10091010
1010- def execute (self , dbname , query , username = None , commit = True ):
1011+ async def execute (self , dbname , query , username = None , commit = True ):
1012+ """
1013+ Execute a query
1014+
1015+ Args:
1016+ dbname: database name to connect to.
1017+ query: query to be executed.
1018+ username: database user name.
1019+ commit: should we commit this query?
1020+
1021+ Returns:
1022+ A list of tuples representing rows.
1023+ """
1024+
1025+ async with self .connect (dbname , username ) as node_con :
1026+ if commit :
1027+ await node_con .begin ()
1028+
1029+ await node_con .execute (query )
1030+ if commit :
1031+ await node_con .commit ()
1032+
1033+ async def fetch (self , dbname , query , username = None , commit = True ):
10111034 """
10121035 Execute a query and return all rows as list.
10131036
@@ -1021,10 +1044,13 @@ def execute(self, dbname, query, username=None, commit=True):
10211044 A list of tuples representing rows.
10221045 """
10231046
1024- with self .connect (dbname , username ) as node_con :
1025- res = node_con .execute (query )
1047+ async with self .connect (dbname , username ) as node_con :
1048+ if commit :
1049+ await node_con .begin ()
1050+
1051+ res = await node_con .fetch (query )
10261052 if commit :
1027- node_con .commit ()
1053+ await node_con .commit ()
10281054 return res
10291055
10301056 def backup (self , username = None , xlog_method = DEFAULT_XLOG_METHOD ):
@@ -1059,7 +1085,7 @@ def replicate(self, name, username=None,
10591085 backup = self .backup (username = username , xlog_method = xlog_method )
10601086 return backup .spawn_replica (name , use_logging = use_logging )
10611087
1062- def catchup (self , username = None ):
1088+ async def catchup (self , username = None ):
10631089 """
10641090 Wait until async replica catches up with its master.
10651091 """
@@ -1080,8 +1106,8 @@ def catchup(self, username=None):
10801106 raise CatchUpException ("Master node is not specified" )
10811107
10821108 try :
1083- lsn = master .execute ('postgres' , poll_lsn )[0 ][0 ]
1084- self .poll_query_until (dbname = 'postgres' ,
1109+ lsn = ( await master .fetch ('postgres' , poll_lsn ) )[0 ][0 ]
1110+ await self .poll_query_until (dbname = 'postgres' ,
10851111 username = username ,
10861112 query = wait_lsn .format (lsn ),
10871113 max_attempts = 0 ) # infinite
0 commit comments