From 02c337551133888dfbde2ed0b2fba7cc0c65429e Mon Sep 17 00:00:00 2001
From: "v.shepard"
Date: Mon, 10 Apr 2023 23:03:36 +0200
Subject: [PATCH 01/23] PBCKP-137 update node.py

---
 testgres/__init__.py |   4 +-
 testgres/node.py     | 337 +++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 327 insertions(+), 14 deletions(-)

diff --git a/testgres/__init__.py b/testgres/__init__.py
index 9d5e37cf..1b33ba3b 100644
--- a/testgres/__init__.py
+++ b/testgres/__init__.py
@@ -32,7 +32,7 @@
     ProcessType, \
     DumpFormat

-from .node import PostgresNode
+from .node import PostgresNode, NodeApp

 from .utils import \
     reserve_port, \
@@ -53,7 +53,7 @@
     "NodeConnection", "DatabaseError", "InternalError", "ProgrammingError", "OperationalError",
     "TestgresException", "ExecUtilException", "QueryException", "TimeoutException", "CatchUpException", "StartNodeException", "InitNodeException", "BackupException",
     "XLogMethod", "IsolationLevel", "NodeStatus", "ProcessType", "DumpFormat",
-    "PostgresNode",
+    "PostgresNode", "NodeApp",
     "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version",
     "First", "Any",
 ]
diff --git a/testgres/node.py b/testgres/node.py
index 378e6803..0d1232a2 100644
--- a/testgres/node.py
+++ b/testgres/node.py
@@ -2,6 +2,12 @@

 import io
 import os
+import random
+import shutil
+import signal
+import threading
+from queue import Queue
+
 import psutil
 import subprocess
 import time
@@ -11,6 +17,15 @@
 except ImportError:
     from collections import Iterable

+# we support both pg8000 and psycopg2
+try:
+    import psycopg2 as pglib
+except ImportError:
+    try:
+        import pg8000 as pglib
+    except ImportError:
+        raise ImportError("You must have psycopg2 or pg8000 modules installed")
+
 from shutil import rmtree
 from six import raise_from, iteritems, text_type
 from tempfile import mkstemp, mkdtemp
@@ -86,6 +101,10 @@
 from .backup import NodeBackup

+InternalError = pglib.InternalError
+ProgrammingError = pglib.ProgrammingError
+OperationalError = pglib.OperationalError
+

 class ProcessProxy(object):
     """
@@ -140,6 +159,9 @@ def __init__(self, name=None, port=None, base_dir=None):
         self.utils_log_name = self.utils_log_file
         self.pg_log_name = self.pg_log_file

+        # Node state
+        self.is_started = False
+
     def __enter__(self):
         return self
@@ -629,9 +651,38 @@ def get_control_data(self):

         return out_dict

+    def slow_start(self, replica=False):
+        """
+        Starts the PostgreSQL instance and then polls the instance
+        until it reaches the expected state (primary or replica). The state is checked
+        using the pg_is_in_recovery() function.
+
+        Args:
+            replica: If True, waits for the instance to be in recovery (i.e., replica mode).
+                     If False, waits for the instance to be in primary mode. Default is False.
+        """
+        self.start()
+
+        if replica:
+            query = 'SELECT pg_is_in_recovery()'
+        else:
+            query = 'SELECT not pg_is_in_recovery()'
+        # Call poll_query_until until the expected value is returned
+        self.poll_query_until(
+            dbname="template1",
+            query=query,
+            suppress={pglib.InternalError,
+                      QueryException,
+                      pglib.ProgrammingError,
+                      pglib.OperationalError})
+
     def start(self, params=[], wait=True):
         """
-        Start this node using pg_ctl.
+        Starts the PostgreSQL node using pg_ctl if the node has not been started.
+        By default, it waits for the operation to complete before returning.
+        Optionally, it can return immediately without waiting for the start operation
+        to complete by setting the `wait` parameter to False.

         Args:
             params: additional arguments for pg_ctl.
@@ -640,14 +691,16 @@ def start(self, params=[], wait=True):

         Returns:
             This instance of :class:`.PostgresNode`.
         """
+        if self.is_started:
+            return self

         _params = [
-            get_bin_path("pg_ctl"),
-            "-D", self.data_dir,
-            "-l", self.pg_log_file,
-            "-w" if wait else '-W',  # --wait or --no-wait
-            "start"
-        ] + params  # yapf: disable
+                      get_bin_path("pg_ctl"),
+                      "-D", self.data_dir,
+                      "-l", self.pg_log_file,
+                      "-w" if wait else '-W',  # --wait or --no-wait
+                      "start"
+                  ] + params  # yapf: disable

         try:
             execute_utility(_params, self.utils_log_file)
@@ -657,20 +710,22 @@ def start(self, params=[], wait=True):
             raise_from(StartNodeException(msg, files), e)

         self._maybe_start_logger()
-
+        self.is_started = True
         return self

     def stop(self, params=[], wait=True):
         """
-        Stop this node using pg_ctl.
+        Stops the PostgreSQL node using pg_ctl if the node has been started.

         Args:
-            params: additional arguments for pg_ctl.
-            wait: wait until operation completes.
+            params: A list of additional arguments for pg_ctl. Defaults to an empty list.
+            wait: If True, waits until the operation is complete. Defaults to True.

         Returns:
             This instance of :class:`.PostgresNode`.
         """
+        if not self.is_started:
+            return self

         _params = [
             get_bin_path("pg_ctl"),
@@ -682,9 +737,25 @@ def stop(self, params=[], wait=True):
             execute_utility(_params, self.utils_log_file)

         self._maybe_stop_logger()
-
+        self.is_started = False
         return self

+    def kill(self, someone=None):
+        """
+        Kills the PostgreSQL node or a specified auxiliary process if the node is running.
+
+        Args:
+            someone: A key to the auxiliary process in the auxiliary_pids dictionary.
+                     If None, the main PostgreSQL node process will be killed. Defaults to None.
+        """
+        if self.is_started:
+            sig = signal.SIGKILL if os.name != 'nt' else signal.SIGBREAK
+            if someone is None:
+                os.kill(self.pid, sig)
+            else:
+                os.kill(self.auxiliary_pids[someone][0], sig)
+            self.is_started = False
+
     def restart(self, params=[]):
         """
         Restart this node using pg_ctl.
@@ -1359,3 +1430,245 @@ def connect(self,
                               username=username,
                               password=password,
                               autocommit=autocommit)  # yapf: disable
+
+    def table_checksum(self, table, dbname="postgres"):
+        """
+        Calculate the checksum of a table by hashing its rows.
+
+        The function fetches rows from the table in chunks and calculates the checksum
+        by summing the hash values of each row. The function uses a separate thread
+        to fetch rows when there are more than 2000 rows in the table.
+
+        Args:
+            table (str): The name of the table for which the checksum should be calculated.
+            dbname (str, optional): The name of the database where the table is located. Defaults to "postgres".
+
+        Returns:
+            int: The calculated checksum of the table.
+ """ + + def fetch_rows(con, cursor_name): + while True: + rows = con.execute(f"FETCH FORWARD 2000 FROM {cursor_name}") + if not rows: + break + yield rows + + def process_rows(queue, con, cursor_name): + try: + for rows in fetch_rows(con, cursor_name): + queue.put(rows) + except Exception as e: + queue.put(e) + else: + queue.put(None) + + cursor_name = f"cur_{random.randint(0, 2 ** 48)}" + checksum = 0 + query_thread = None + + with self.connect(dbname=dbname) as con: + con.execute(f""" + DECLARE {cursor_name} NO SCROLL CURSOR FOR + SELECT t::text FROM {table} as t + """) + + queue = Queue(maxsize=50) + initial_rows = con.execute(f"FETCH FORWARD 2000 FROM {cursor_name}") + + if not initial_rows: + return 0 + + queue.put(initial_rows) + + if len(initial_rows) == 2000: + query_thread = threading.Thread(target=process_rows, args=(queue, con, cursor_name)) + query_thread.start() + else: + queue.put(None) + + while True: + rows = queue.get() + if rows is None: + break + if isinstance(rows, Exception): + raise rows + + for row in rows: + checksum += hash(row[0]) + + if query_thread is not None: + query_thread.join() + + con.execute(f"CLOSE {cursor_name}; ROLLBACK;") + + return checksum + + def pgbench_table_checksums(self, dbname="postgres", + pgbench_tables=('pgbench_branches', + 'pgbench_tellers', + 'pgbench_accounts', + 'pgbench_history') + ): + """ + Calculate the checksums of the specified pgbench tables using table_checksum method. + + Args: + dbname (str, optional): The name of the database where the pgbench tables are located. Defaults to "postgres". + pgbench_tables (tuple of str, optional): A tuple containing the names of the pgbench tables for which the + checksums should be calculated. Defaults to a tuple containing the + names of the default pgbench tables. + + Returns: + set of tuple: A set of tuples, where each tuple contains the table name and its corresponding checksum. + """ + return {(table, self.table_checksum(table, dbname)) + for table in pgbench_tables} + + def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): + """ + Update or remove configuration options in the specified configuration file, + updates the options specified in the options dictionary, removes any options + specified in the rm_options set, and writes the updated configuration back to + the file. + + Args: + options (dict): A dictionary containing the options to update or add, + with the option names as keys and their values as values. + config (str, optional): The name of the configuration file to update. + Defaults to 'postgresql.auto.conf'. + rm_options (set, optional): A set containing the names of the options to remove. + Defaults to an empty set. 
+ """ + # parse postgresql.auto.conf + path = os.path.join(self.data_dir, config) + + with open(path, 'r') as f: + raw_content = f.read() + + current_options = {} + current_directives = [] + for line in raw_content.splitlines(): + + # ignore comments + if line.startswith('#'): + continue + + if line == '': + continue + + if line.startswith('include'): + current_directives.append(line) + continue + + name, var = line.partition('=')[::2] + name = name.strip() + var = var.strip() + var = var.strip('"') + var = var.strip("'") + + # remove options specified in rm_options list + if name in rm_options: + continue + + current_options[name] = var + + for option in options: + current_options[option] = options[option] + + auto_conf = '' + for option in current_options: + auto_conf += "{0} = '{1}'\n".format( + option, current_options[option]) + + for directive in current_directives: + auto_conf += directive + "\n" + + with open(path, 'wt') as f: + f.write(auto_conf) + + +class NodeApp: + """ + Functions that can be moved to testgres.PostgresNode + We use these functions in ProbackupController and need tp move them in some visible place + """ + + def __init__(self, test_path, nodes_to_cleanup): + self.test_path = test_path + self.nodes_to_cleanup = nodes_to_cleanup + + def make_empty( + self, + base_dir=None): + real_base_dir = os.path.join(self.test_path, base_dir) + shutil.rmtree(real_base_dir, ignore_errors=True) + os.makedirs(real_base_dir) + + node = PostgresNodeExtended(base_dir=real_base_dir) + node.should_rm_dirs = True + self.nodes_to_cleanup.append(node) + + return node + + def make_simple( + self, + base_dir=None, + set_replication=False, + ptrack_enable=False, + initdb_params=[], + pg_options={}): + + node = self.make_empty(base_dir) + node.init( + initdb_params=initdb_params, allow_streaming=set_replication) + + # set major version + with open(os.path.join(node.data_dir, 'PG_VERSION')) as f: + node.major_version_str = str(f.read().rstrip()) + node.major_version = float(node.major_version_str) + + # Sane default parameters + options = {} + options['max_connections'] = 100 + options['shared_buffers'] = '10MB' + options['fsync'] = 'off' + + options['wal_level'] = 'logical' + options['hot_standby'] = 'off' + + options['log_line_prefix'] = '%t [%p]: [%l-1] ' + options['log_statement'] = 'none' + options['log_duration'] = 'on' + options['log_min_duration_statement'] = 0 + options['log_connections'] = 'on' + options['log_disconnections'] = 'on' + options['restart_after_crash'] = 'off' + options['autovacuum'] = 'off' + + # Allow replication in pg_hba.conf + if set_replication: + options['max_wal_senders'] = 10 + + if ptrack_enable: + options['ptrack.map_size'] = '1' + options['shared_preload_libraries'] = 'ptrack' + + if node.major_version >= 13: + options['wal_keep_size'] = '200MB' + else: + options['wal_keep_segments'] = '12' + + # set default values + node.set_auto_conf(options) + + # Apply given parameters + node.set_auto_conf(pg_options) + + # kludge for testgres + # https://github.com/postgrespro/testgres/issues/54 + # for PG >= 13 remove 'wal_keep_segments' parameter + if node.major_version >= 13: + node.set_auto_conf({}, 'postgresql.conf', ['wal_keep_segments']) + + return node From 1512afde8a40bee606046e6a305aa4017ec8419a Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Tue, 11 Apr 2023 15:02:00 +0200 Subject: [PATCH 02/23] PBCKP-137 up version 1.8.6 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a5dc600e..5c6f4a07 100755 --- 
From 1512afde8a40bee606046e6a305aa4017ec8419a Mon Sep 17 00:00:00 2001
From: "v.shepard"
Date: Tue, 11 Apr 2023 15:02:00 +0200
Subject: [PATCH 02/23] PBCKP-137 up version 1.8.6

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index a5dc600e..5c6f4a07 100755
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,7 @@
     readme = f.read()

 setup(
-    version='1.8.5',
+    version='1.8.6',
     name='testgres',
     packages=['testgres'],
     description='Testing utility for PostgreSQL and its extensions',

From 0d62e0e6881a8cd18e9acd58507fcae74ce71ad9 Mon Sep 17 00:00:00 2001
From: "v.shepard"
Date: Tue, 11 Apr 2023 22:50:33 +0200
Subject: [PATCH 03/23] PBCKP-137 update node.py

---
 testgres/node.py | 163 +++++++++++++++++++----------------------------
 1 file changed, 67 insertions(+), 96 deletions(-)

diff --git a/testgres/node.py b/testgres/node.py
index 0d1232a2..6d1d4544 100644
--- a/testgres/node.py
+++ b/testgres/node.py
@@ -12,6 +12,7 @@
 import subprocess
 import time

+
 try:
     from collections.abc import Iterable
 except ImportError:
@@ -104,6 +105,7 @@
 InternalError = pglib.InternalError
 ProgrammingError = pglib.ProgrammingError
 OperationalError = pglib.OperationalError
+DatabaseError = pglib.DatabaseError


 class ProcessProxy(object):
@@ -651,13 +653,15 @@ def get_control_data(self):

         return out_dict

-    def slow_start(self, replica=False):
+    def slow_start(self, replica=False, dbname='template1', username='dev'):
         """
         Starts the PostgreSQL instance and then polls the instance
         until it reaches the expected state (primary or replica). The state is checked
         using the pg_is_in_recovery() function.

         Args:
+            dbname: database name to connect to (default 'template1').
+            username: database user name.
             replica: If True, waits for the instance to be in recovery (i.e., replica mode).
                      If False, waits for the instance to be in primary mode. Default is False.
         """
@@ -668,14 +672,15 @@ def slow_start(self, replica=False):
         else:
             query = 'SELECT not pg_is_in_recovery()'
         # Call poll_query_until until the expected value is returned
-        self.poll_query_until(
-            dbname="template1",
-            query=query,
-            suppress={pglib.InternalError,
-                      QueryException,
-                      pglib.ProgrammingError,
-                      pglib.OperationalError})
-
+        self.poll_query_until(query=query,
+                              expected=False,
+                              dbname=dbname,
+                              username=username,
+                              suppress={InternalError,
+                                        QueryException,
+                                        ProgrammingError,
+                                        OperationalError,
+                                        DatabaseError})

     def start(self, params=[], wait=True):
         """
@@ -1432,96 +1437,66 @@ def connect(self,
                               autocommit=autocommit)  # yapf: disable

     def table_checksum(self, table, dbname="postgres"):
-        """
-        Calculate the checksum of a table by hashing its rows.
-
-        The function fetches rows from the table in chunks and calculates the checksum
-        by summing the hash values of each row. The function uses a separate thread
-        to fetch rows when there are more than 2000 rows in the table.
-
-        Args:
-            table (str): The name of the table for which the checksum should be calculated.
-            dbname (str, optional): The name of the database where the table is located. Defaults to "postgres".
-
-        Returns:
-            int: The calculated checksum of the table.
- """ - - def fetch_rows(con, cursor_name): - while True: - rows = con.execute(f"FETCH FORWARD 2000 FROM {cursor_name}") - if not rows: - break - yield rows - - def process_rows(queue, con, cursor_name): - try: - for rows in fetch_rows(con, cursor_name): - queue.put(rows) - except Exception as e: - queue.put(e) - else: - queue.put(None) - - cursor_name = f"cur_{random.randint(0, 2 ** 48)}" - checksum = 0 - query_thread = None - - with self.connect(dbname=dbname) as con: - con.execute(f""" - DECLARE {cursor_name} NO SCROLL CURSOR FOR - SELECT t::text FROM {table} as t - """) - - queue = Queue(maxsize=50) - initial_rows = con.execute(f"FETCH FORWARD 2000 FROM {cursor_name}") - - if not initial_rows: - return 0 - - queue.put(initial_rows) - - if len(initial_rows) == 2000: - query_thread = threading.Thread(target=process_rows, args=(queue, con, cursor_name)) - query_thread.start() - else: - queue.put(None) + con = self.connect(dbname=dbname) + + curname = "cur_" + str(random.randint(0, 2 ** 48)) + + con.execute(""" + DECLARE %s NO SCROLL CURSOR FOR + SELECT t::text FROM %s as t + """ % (curname, table)) + + que = Queue(maxsize=50) + sum = 0 + + rows = con.execute("FETCH FORWARD 2000 FROM %s" % curname) + if not rows: + return 0 + que.put(rows) + + th = None + if len(rows) == 2000: + def querier(): + try: + while True: + rows = con.execute("FETCH FORWARD 2000 FROM %s" % curname) + if not rows: + break + que.put(rows) + except Exception as e: + que.put(e) + else: + que.put(None) - while True: - rows = queue.get() - if rows is None: - break - if isinstance(rows, Exception): - raise rows + th = threading.Thread(target=querier) + th.start() + else: + que.put(None) - for row in rows: - checksum += hash(row[0]) + while True: + rows = que.get() + if rows is None: + break + if isinstance(rows, Exception): + raise rows + # hash uses SipHash since Python3.4, therefore it is good enough + for row in rows: + sum += hash(row[0]) - if query_thread is not None: - query_thread.join() + if th is not None: + th.join() - con.execute(f"CLOSE {cursor_name}; ROLLBACK;") + con.execute("CLOSE %s; ROLLBACK;" % curname) - return checksum + con.close() + return sum def pgbench_table_checksums(self, dbname="postgres", - pgbench_tables=('pgbench_branches', - 'pgbench_tellers', - 'pgbench_accounts', - 'pgbench_history') + pgbench_tables = ('pgbench_branches', + 'pgbench_tellers', + 'pgbench_accounts', + 'pgbench_history') ): - """ - Calculate the checksums of the specified pgbench tables using table_checksum method. - - Args: - dbname (str, optional): The name of the database where the pgbench tables are located. Defaults to "postgres". - pgbench_tables (tuple of str, optional): A tuple containing the names of the pgbench tables for which the - checksums should be calculated. Defaults to a tuple containing the - names of the default pgbench tables. - - Returns: - set of tuple: A set of tuples, where each tuple contains the table name and its corresponding checksum. 
- """ return {(table, self.table_checksum(table, dbname)) for table in pgbench_tables} @@ -1589,10 +1564,6 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): class NodeApp: - """ - Functions that can be moved to testgres.PostgresNode - We use these functions in ProbackupController and need tp move them in some visible place - """ def __init__(self, test_path, nodes_to_cleanup): self.test_path = test_path @@ -1605,7 +1576,7 @@ def make_empty( shutil.rmtree(real_base_dir, ignore_errors=True) os.makedirs(real_base_dir) - node = PostgresNodeExtended(base_dir=real_base_dir) + node = PostgresNode(base_dir=real_base_dir) node.should_rm_dirs = True self.nodes_to_cleanup.append(node) From 8be1b3a72cecd7dd15862c3258b97fb5834e6737 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 17 Apr 2023 10:43:16 +0200 Subject: [PATCH 04/23] PBCKP-137 update node --- testgres/node.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 6d1d4544..17b9a260 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -105,7 +105,6 @@ InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError OperationalError = pglib.OperationalError -DatabaseError = pglib.DatabaseError class ProcessProxy(object): @@ -653,7 +652,7 @@ def get_control_data(self): return out_dict - def slow_start(self, replica=False, dbname='template1', username='dev'): + def slow_start(self, replica=False, dbname='template1', username=default_username()): """ Starts the PostgreSQL instance and then polls the instance until it reaches the expected state (primary or replica). The state is checked @@ -673,14 +672,12 @@ def slow_start(self, replica=False, dbname='template1', username='dev'): query = 'SELECT not pg_is_in_recovery()' # Call poll_query_until until the expected value is returned self.poll_query_until(query=query, - expected=False, dbname=dbname, username=username, suppress={InternalError, QueryException, ProgrammingError, - OperationalError, - DatabaseError}) + OperationalError}) def start(self, params=[], wait=True): """ @@ -970,7 +967,7 @@ def psql(self, return process.returncode, out, err @method_decorator(positional_args_hack(['dbname', 'query'])) - def safe_psql(self, query=None, **kwargs): + def safe_psql(self, query=None, expect_error=False, **kwargs): """ Execute a query using psql. @@ -980,6 +977,8 @@ def safe_psql(self, query=None, **kwargs): dbname: database name to connect to. username: database user name. input: raw input to be passed. + expect_error: if True - fail if we didn't get ret + if False - fail if we got ret **kwargs are passed to psql(). 
@@ -992,7 +991,12 @@ def safe_psql(self, query=None, **kwargs):

         ret, out, err = self.psql(query=query, **kwargs)
         if ret:
-            raise QueryException((err or b'').decode('utf-8'), query)
+            if expect_error:
+                out = (err or b'').decode('utf-8')
+            else:
+                raise QueryException((err or b'').decode('utf-8'), query)
+        elif expect_error:
+            assert False, f"Exception was expected, but query finished successfully: `{query}` "

         return out
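
A short sketch of the expect_error contract introduced above, given a started node; the failing query is deliberately bogus:

    node.safe_psql('postgres', 'SELECT 1')                     # returns stdout
    err = node.safe_psql('postgres', 'SELECT * FROM missing_tbl',
                         expect_error=True)                    # returns stderr text
    node.safe_psql('postgres', 'SELECT 1', expect_error=True)  # AssertionError: success was not expected
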
From 51f05de66ebc5604dc72ff17af08dd7d4fb1c9a1 Mon Sep 17 00:00:00 2001
From: "v.shepard"
Date: Tue, 2 May 2023 15:39:40 +0200
Subject: [PATCH 05/23] PBCKP-152 change local function on execution by ssh

---
 setup.py               |   4 +-
 testgres/api.py        |   6 +
 testgres/backup.py     |  13 +-
 testgres/cache.py      |  20 +--
 testgres/config.py     |  21 ++-
 testgres/connection.py |  39 +++---
 testgres/defaults.py   |  16 ++-
 testgres/logger.py     |  14 ++
 testgres/node.py       |  23 +++-
 testgres/os_ops.py     | 293 +++++++++++++++++++++++++++++++++++++++++
 testgres/utils.py      |  21 ++-
 11 files changed, 422 insertions(+), 48 deletions(-)
 create mode 100644 testgres/os_ops.py

diff --git a/setup.py b/setup.py
index 5c6f4a07..8ae54e4f 100755
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 from distutils.core import setup

 # Basic dependencies
-install_requires = ["pg8000", "port-for>=0.4", "six>=1.9.0", "psutil"]
+install_requires = ["pg8000", "port-for>=0.4", "six>=1.9.0", "psutil", "fabric"]

 # Add compatibility enum class
 if sys.version_info < (3, 4):
@@ -21,7 +21,7 @@
     readme = f.read()

 setup(
-    version='1.8.6',
+    version='1.9.0',
     name='testgres',
     packages=['testgres'],
     description='Testing utility for PostgreSQL and its extensions',
diff --git a/testgres/api.py b/testgres/api.py
index e90cf7bd..bae46717 100644
--- a/testgres/api.py
+++ b/testgres/api.py
@@ -30,6 +30,7 @@
     PostgresNode(name='...', port=..., base_dir='...')
     [(3,)]
 """
+from .defaults import default_username
 from .node import PostgresNode


@@ -37,6 +38,11 @@ def get_new_node(name=None, base_dir=None, **kwargs):
     """
     Simply a wrapper around :class:`.PostgresNode` constructor.
     See :meth:`.PostgresNode.__init__` for details.
+    For a remote connection you can pass the following parameters:
+    host='127.0.0.1',
+    hostname='localhost',
+    ssh_key=None,
+    username=default_username()
     """
     # NOTE: leave explicit 'name' and 'base_dir' for compatibility
     return PostgresNode(name=name, base_dir=base_dir, **kwargs)
diff --git a/testgres/backup.py b/testgres/backup.py
index a725a1df..b3ffa833 100644
--- a/testgres/backup.py
+++ b/testgres/backup.py
@@ -6,6 +6,7 @@
 from six import raise_from
 from tempfile import mkdtemp

+from .os_ops import OsOperations
 from .enums import XLogMethod

 from .consts import \
@@ -47,7 +48,7 @@ def __init__(self,
             username: database user name.
             xlog_method: none | fetch | stream (see docs)
         """
-
+        self.os_ops = node.os_ops
         if not node.status():
             raise BackupException('Node must be running')
@@ -60,8 +61,8 @@ def __init__(self,
             raise BackupException(msg)

         # Set default arguments
-        username = username or default_username()
-        base_dir = base_dir or mkdtemp(prefix=TMP_BACKUP)
+        username = username or self.os_ops.get_user()
+        base_dir = base_dir or self.os_ops.mkdtemp(prefix=TMP_BACKUP)

         # public
         self.original_node = node
@@ -81,7 +82,7 @@ def __init__(self,
             "-D", data_dir,
             "-X", xlog_method.value
         ]  # yapf: disable
-        execute_utility(_params, self.log_file)
+        execute_utility(_params, self.log_file, hostname=node.hostname, ssh_key=node.ssh_key)

     def __enter__(self):
         return self
@@ -107,7 +108,7 @@ def _prepare_dir(self, destroy):
         available = not destroy

         if available:
-            dest_base_dir = mkdtemp(prefix=TMP_NODE)
+            dest_base_dir = self.os_ops.mkdtemp(prefix=TMP_NODE)

             data1 = os.path.join(self.base_dir, DATA_DIR)
             data2 = os.path.join(dest_base_dir, DATA_DIR)
@@ -185,4 +186,4 @@ def cleanup(self):

         if self._available:
             self._available = False
-            rmtree(self.base_dir, ignore_errors=True)
+            self.os_ops.rmdirs(self.base_dir, ignore_errors=True)
diff --git a/testgres/cache.py b/testgres/cache.py
index c3cd9971..cd40e72f 100644
--- a/testgres/cache.py
+++ b/testgres/cache.py
@@ -6,6 +6,7 @@
 from shutil import copytree
 from six import raise_from

+from .os_ops import OsOperations
 from .config import testgres_config

 from .consts import XLOG_CONTROL_FILE
@@ -21,14 +22,16 @@
     execute_utility


-def cached_initdb(data_dir, logfile=None, params=None):
+def cached_initdb(data_dir, logfile=None, hostname='localhost', ssh_key=None, params=None):
     """
     Perform initdb or use cached node files.
     """
+    os_ops = OsOperations(hostname=hostname, ssh_key=ssh_key)
+
     def call_initdb(initdb_dir, log=None):
         try:
             _params = [get_bin_path("initdb"), "-D", initdb_dir, "-N"]
-            execute_utility(_params + (params or []), log)
+            execute_utility(_params + (params or []), log, hostname=hostname, ssh_key=ssh_key)
         except ExecUtilException as e:
             raise_from(InitNodeException("Failed to run initdb"), e)

@@ -39,13 +42,14 @@ def call_initdb(initdb_dir, log=None):
         cached_data_dir = testgres_config.cached_initdb_dir

         # Initialize cached initdb

-        if not os.path.exists(cached_data_dir) or \
-           not os.listdir(cached_data_dir):
+        if not os_ops.path_exists(cached_data_dir) or \
+           not os_ops.listdir(cached_data_dir):
             call_initdb(cached_data_dir)

         try:
             # Copy cached initdb to current data dir
-            copytree(cached_data_dir, data_dir)
+            os_ops.copytree(cached_data_dir, data_dir)

             # Assign this node a unique system id if asked to
             if testgres_config.cached_initdb_unique:
                 # XXX: write new unique system id to control file
                 # Some users might rely upon unique system ids, but
                 # our initdb caching mechanism breaks this contract.
                 pg_control = os.path.join(data_dir, XLOG_CONTROL_FILE)
-                with io.open(pg_control, "r+b") as f:
-                    f.write(generate_system_id())  # overwrite id
+                system_id = generate_system_id()
+                os_ops.write(pg_control, system_id, truncate=True, binary=True, read_and_write=True)

                 # XXX: build new WAL segment with our system id
                 _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"]
-                execute_utility(_params, logfile)
+                execute_utility(_params, logfile, hostname=hostname, ssh_key=ssh_key)

         except ExecUtilException as e:
             msg = "Failed to reset WAL for system id"
diff --git a/testgres/config.py b/testgres/config.py
index cfcdadc2..b32536ba 100644
--- a/testgres/config.py
+++ b/testgres/config.py
@@ -43,6 +43,9 @@ class GlobalConfig(object):
     _cached_initdb_dir = None
     """ underlying class attribute for cached_initdb_dir property """

+    os_ops = None
+    """ OsOperation object that allows work on remote host """
+
     @property
     def cached_initdb_dir(self):
         """ path to a temp directory for cached initdb. """
@@ -52,8 +55,15 @@ def cached_initdb_dir(self, value):
         self._cached_initdb_dir = value

+        # NOTE: assign initial cached dir for initdb
+        if self.os_ops:
+            testgres_config.cached_initdb_dir = self.os_ops.mkdtemp(prefix=TMP_CACHE)
+        else:
+            testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE)
+
         if value:
             cached_initdb_dirs.add(value)
+        return testgres_config.cached_initdb_dir

     @property
     def temp_dir(self):
@@ -133,9 +143,12 @@ def copy(self):


 @atexit.register
-def _rm_cached_initdb_dirs():
+def _rm_cached_initdb_dirs(os_ops=None):
     for d in cached_initdb_dirs:
-        rmtree(d, ignore_errors=True)
+        if os_ops:
+            os_ops.rmtree(d, ignore_errors=True)
+        else:
+            rmtree(d, ignore_errors=True)


 def push_config(**options):
@@ -195,7 +208,3 @@ def configure_testgres(**options):
     """

     testgres_config.update(options)
-
-
-# NOTE: assign initial cached dir for initdb
-testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE)
diff --git a/testgres/connection.py b/testgres/connection.py
index ee2a2128..f85f56be 100644
--- a/testgres/connection.py
+++ b/testgres/connection.py
@@ -1,4 +1,5 @@
 # coding: utf-8
+from .os_ops import OsOperations

 # we support both pg8000 and psycopg2
 try:
@@ -41,11 +42,12 @@ def __init__(self,

         self._node = node

-        self._connection = pglib.connect(database=dbname,
-                                         user=username,
-                                         password=password,
-                                         host=node.host,
-                                         port=node.port)
+        self.os_ops = OsOperations(node.host, node.hostname, node.ssh_key, node.username)
+        self._connection = self.os_ops.db_connect(dbname=dbname,
+                                                  user=username,
+                                                  password=password,
+                                                  host=node.host,
+                                                  port=node.port)

         self._connection.autocommit = autocommit
         self._cursor = self.connection.cursor()
@@ -102,17 +104,24 @@ def rollback(self):
         return self

     def execute(self, query, *args):
-        self.cursor.execute(query, args)
-        try:
-            res = self.cursor.fetchall()
-
-            # pg8000 might return tuples
-            if isinstance(res, tuple):
-                res = [tuple(t) for t in res]
-
-            return res
-        except Exception:
-            return None
+        with self.connection.cursor() as cursor:
+            cursor.execute(query, args)
+            try:
+                res = cursor.fetchall()
+
+                # pg8000 might return tuples
+                if isinstance(res, tuple):
+                    res = [tuple(t) for t in res]
+
+                return res
+            except (pglib.ProgrammingError, pglib.InternalError) as e:
+                # An error occurred while trying to fetch results (e.g., no results to fetch)
+                print(f"Error fetching results: {e}")
+                return None
+            except (pglib.Error, Exception) as e:
+                # Handle other database errors
+                print(f"Error executing query: {e}")
+                return None

     def close(self):
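
With the rewritten execute() above, fetch and database errors are logged and swallowed rather than raised, so callers must check for None; a minimal sketch, assuming a NodeConnection obtained from node.connect():

    with node.connect(dbname='postgres') as con:
        res = con.execute('SELECT 1')
        if res is None:
            raise RuntimeError('query failed, or there were no results to fetch')
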
diff --git a/testgres/defaults.py b/testgres/defaults.py
index 8d5b892e..539183ae 100644
--- a/testgres/defaults.py
+++ b/testgres/defaults.py
@@ -13,12 +13,15 @@ def default_dbname():
     return 'postgres'


-def default_username():
+def default_username(os_ops=None):
     """
     Return default username (current user).
     """
-
-    return getpass.getuser()
+    if os_ops:
+        user = os_ops.get_user()
+    else:
+        user = getpass.getuser()
+    return user


 def generate_app_name():
@@ -29,7 +32,7 @@ def generate_app_name():
     return 'testgres-{}'.format(str(uuid.uuid4()))


-def generate_system_id():
+def generate_system_id(os_ops=None):
     """
     Generate a new 64-bit unique system identifier for node.
     """
@@ -44,7 +47,10 @@ def generate_system_id():
     system_id = 0
     system_id |= (secs << 32)
     system_id |= (usecs << 12)
-    system_id |= (os.getpid() & 0xFFF)
+    if os_ops:
+        system_id |= (os_ops.get_pid() & 0xFFF)
+    else:
+        system_id |= (os.getpid() & 0xFFF)

     # pack ULL in native byte order
     return struct.pack('=Q', system_id)
diff --git a/testgres/logger.py b/testgres/logger.py
index b4648f44..abd4d255 100644
--- a/testgres/logger.py
+++ b/testgres/logger.py
@@ -6,6 +6,20 @@
 import time


+# create logger
+log = logging.getLogger('Testgres')
+log.setLevel(logging.DEBUG)
+# create console handler and set level to debug
+ch = logging.StreamHandler()
+ch.setLevel(logging.DEBUG)
+# create formatter
+formatter = logging.Formatter('\n%(asctime)s - %(name)s[%(levelname)s]: %(message)s')
+# add formatter to ch
+ch.setFormatter(formatter)
+# add ch to logger
+log.addHandler(ch)
+
+
 class TestgresLogger(threading.Thread):
     """
     Helper class to implement reading from log files.
diff --git a/testgres/node.py b/testgres/node.py
index 17b9a260..d895bde8 100644
--- a/testgres/node.py
+++ b/testgres/node.py
@@ -12,6 +12,7 @@
 import subprocess
 import time

+from .os_ops import OsOperations

 try:
     from collections.abc import Iterable
@@ -129,7 +130,8 @@ def __repr__(self):


 class PostgresNode(object):
-    def __init__(self, name=None, port=None, base_dir=None):
+    def __init__(self, name=None, port=None, base_dir=None,
+                 host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username()):
         """
         PostgresNode constructor.
@@ -147,10 +149,14 @@ def __init__(self, name=None, port=None, base_dir=None,
         self._master = None

         # basic
-        self.host = '127.0.0.1'
         self.name = name or generate_app_name()
         self.port = port or reserve_port()

+        self.host = host
+        self.hostname = hostname
+        self.ssh_key = ssh_key
+        self.os_ops = OsOperations(host, hostname, ssh_key, username=username)
+
         # defaults for __exit__()
         self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit
         self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit
@@ -455,9 +461,12 @@ def init(self, initdb_params=None, **kwargs):
         """

         # initialize this PostgreSQL node
-        cached_initdb(data_dir=self.data_dir,
-                      logfile=self.utils_log_file,
-                      params=initdb_params)
+        cached_initdb(
+            data_dir=self.data_dir,
+            logfile=self.utils_log_file,
+            hostname=self.hostname,
+            ssh_key=self.ssh_key,
+            params=initdb_params)

         # initialize default config files
         self.default_conf(**kwargs)
@@ -514,6 +523,10 @@ def get_auth_method(t):
             new_lines = [
                 u"local\treplication\tall\t\t\t{}\n".format(auth_local),
                 u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host),
+
+                u"host\treplication\tall\t0.0.0.0/0\t{}\n".format(auth_host),
+                u"host\tall\tall\t0.0.0.0/0\t{}\n".format(auth_host),
+
                 u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host)
             ]  # yapf: disable

diff --git a/testgres/os_ops.py b/testgres/os_ops.py
new file mode 100644
index 00000000..290fee77
--- /dev/null
+++ b/testgres/os_ops.py
@@ -0,0 +1,293 @@
+import base64
+import getpass
+import os
+import shutil
+import subprocess
+import tempfile
+from contextlib import contextmanager
+from shutil import rmtree
+
+try:
+    import psycopg2 as pglib
+except ImportError:
+    try:
+        import pg8000 as pglib
+    except ImportError:
+        raise ImportError("You must have psycopg2 or pg8000 modules installed")
+
+from .defaults import default_username
+from testgres.logger import log
+
+import paramiko
+
+
+class OsOperations:
+
+    def __init__(self, host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username()):
+        self.host = host
+        self.ssh_key = ssh_key
+        self.username = username
+        self.remote = not (self.host == '127.0.0.1' and hostname == 'localhost')
+        self.ssh = None
+
+        if self.remote:
+            self.ssh = self.connect()
+
+    def __del__(self):
+        if self.ssh:
+            self.ssh.close()
+
+    @contextmanager
+    def ssh_connect(self):
+        if not self.remote:
+            yield None
+        else:
+            with open(self.ssh_key, 'r') as f:
+                key_data = f.read()
+                if 'BEGIN OPENSSH PRIVATE KEY' in key_data:
+                    key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key)
+                else:
+                    key = paramiko.RSAKey.from_private_key_file(self.ssh_key)
+
+            with paramiko.SSHClient() as ssh:
+                ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+                ssh.connect(self.host, username=self.username, pkey=key)
+                yield ssh
+
+    def connect(self):
+        with self.ssh_connect() as ssh:
+            return ssh
+
+    def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False):
+        if isinstance(cmd, list):
+            cmd = ' '.join(cmd)
+        log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}")
+        # Source global profile file + execute command
+        try:
+            if self.remote:
+                cmd = f"source /etc/profile.d/custom.sh; {cmd}"
+                with self.ssh_connect() as ssh:
+                    stdin, stdout, stderr = ssh.exec_command(cmd)
+                    exit_status = 0
+                    if wait_exit:
+                        exit_status = stdout.channel.recv_exit_status()
+                    result = stdout.read().decode('utf-8')
+                    error = stderr.read().decode('utf-8')
+            else:
+                process = subprocess.run(cmd, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
+                                         timeout=60)
+                exit_status = process.returncode
+                result = process.stdout
+                error = process.stderr
+
+            if expect_error:
+                raise Exception(result, error)
+            if exit_status != 0 or 'error' in error.lower():
+                log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}")
+                exit(1)
+
+            if verbose:
+                return exit_status, result, error
+            else:
+                return result
+
+        except Exception as e:
+            log.error(f"Unexpected error while executing command `{cmd}`: {e}")
+            return None
+
+    def makedirs(self, path, remove_existing=False):
+        if remove_existing:
+            cmd = f'rm -rf {path} && mkdir -p {path}'
+        else:
+            cmd = f'mkdir -p {path}'
+        self.exec_command(cmd)
+
+    def rmdirs(self, path, ignore_errors=True):
+        if self.remote:
+            cmd = f'rm -rf {path}'
+            self.exec_command(cmd)
+        else:
+            rmtree(path, ignore_errors=ignore_errors)
+
+    def mkdtemp(self, prefix=None):
+        if self.remote:
+            temp_dir = self.exec_command(f'mktemp -d {prefix}')
+            return temp_dir.strip()
+        else:
+            return tempfile.mkdtemp(prefix=prefix)
+
+    def path_exists(self, path):
+        if self.remote:
+            result = self.exec_command(f'test -e {path}; echo $?')
+            return int(result.strip()) == 0
+        else:
+            return os.path.exists(path)
+
+    def copytree(self, src, dst):
+        if self.remote:
+            self.exec_command(f'cp -r {src} {dst}')
+        else:
+            shutil.copytree(src, dst)
+
+    def listdir(self, path):
+        if self.remote:
+            result = self.exec_command(f'ls {path}')
+            return result.splitlines()
+        else:
+            return os.listdir(path)
+
+    def write(self, filename, data, truncate=False, binary=False, read_and_write=False):
+        """
+        Write data to a file, both locally and on a remote host.
+
+        :param filename: The file path where the data will be written.
+        :param data: The data to be written to the file.
+        :param truncate: If True, the file will be truncated before writing ('w' or 'wb' option);
+                         if False (default), data will be appended ('a' or 'ab' option).
+        :param binary: If True, the data will be written in binary mode ('wb' or 'ab' option);
+                       if False (default), the data will be written in text mode ('w' or 'a' option).
+        :param read_and_write: If True, the file will be opened with read and write permissions ('r+' option);
+                               if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option).
+ """ + mode = 'wb' if binary else 'w' + if not truncate: + mode = 'a' + mode + if read_and_write: + mode = 'r+' + mode + + if self.remote: + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file.write(data) + tmp_file.flush() + + sftp = self.ssh.open_sftp() + sftp.put(tmp_file.name, filename) + sftp.close() + else: + with open(filename, mode) as file: + file.write(data) + + def read(self, filename): + cmd = f'cat {filename}' + return self.exec_command(cmd) + + def readlines(self, filename): + return self.read(filename).splitlines() + + def get_name(self): + cmd = 'python3 -c "import os; print(os.name)"' + return self.exec_command(cmd).strip() + + def kill(self, pid, signal): + cmd = f'kill -{signal} {pid}' + self.exec_command(cmd) + + def environ(self, var_name): + cmd = f"echo ${var_name}" + return self.exec_command(cmd).strip() + + @property + def pathsep(self): + return ':' if self.get_name() == 'posix' else ';' + + def isfile(self, remote_file): + if self.remote: + stdout = self.exec_command(f'test -f {remote_file}; echo $?') + result = int(stdout.strip()) + return result == 0 + else: + return os.path.isfile(remote_file) + + def find_executable(self, executable): + search_paths = self.environ('PATH') + if not search_paths: + return None + + search_paths = search_paths.split(self.pathsep) + for path in search_paths: + remote_file = os.path.join(path, executable) + if self.isfile(remote_file): + return remote_file + + return None + + def is_executable(self, file): + # Check if the file is executable + if self.remote: + if not self.exec_command(f"test -x {file} && echo OK") == 'OK\n': + return False + else: + if not os.access(file, os.X_OK): + return False + return True + + def add_to_path(self, new_path): + os_name = self.get_name() + if os_name == 'posix': + dir_del = ':' + elif os_name == 'nt': + dir_del = ';' + else: + raise Exception(f"Unsupported operating system: {os_name}") + + # Check if the directory is already in PATH + path = self.environ('PATH') + if new_path not in path.split(dir_del): + if self.remote: + self.exec_command(f"export PATH={new_path}{dir_del}{path}") + else: + os.environ['PATH'] = f"{new_path}{dir_del}{path}" + return dir_del + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + if self.remote: + self.exec_command(f"export {var_name}={var_val}") + else: + os.environ[var_name] = var_val + + def get_pid(self): + # Get current process id + if self.remote: + process_id = self.exec_command(f"echo $$") + else: + process_id = os.getpid() + return process_id + + def get_user(self): + # Get current user + if self.remote: + user = self.exec_command(f"echo $USER") + else: + user = getpass.getuser() + return user + + @contextmanager + def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + if self.remote: + with self.ssh_connect() as ssh: + # Set up a local port forwarding on a random port + local_port = ssh.forward_remote_port(host, port) + conn = pglib.connect( + host=host, + port=local_port, + dbname=dbname, + user=user, + password=password, + ) + try: + yield conn + finally: + conn.close() + ssh.close_forwarded_tcp(local_port) + else: + with pglib.connect( + host=host, + port=port, + dbname=dbname, + user=user, + password=password, + ) as conn: + yield conn + + diff --git a/testgres/utils.py b/testgres/utils.py index 4d99c69d..877549e7 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -15,6 +15,8 @@ from distutils.spawn import find_executable from six import iteritems +from fabric import 
diff --git a/testgres/utils.py b/testgres/utils.py
index 4d99c69d..877549e7 100644
--- a/testgres/utils.py
+++ b/testgres/utils.py
@@ -15,6 +15,8 @@
 from distutils.spawn import find_executable
 from six import iteritems

+from fabric import Connection
+
 from .config import testgres_config

 from .exceptions import ExecUtilException
@@ -47,7 +49,7 @@ def release_port(port):
     bound_ports.discard(port)


-def execute_utility(args, logfile=None):
+def execute_utility(args, logfile=None, hostname='localhost', ssh_key=None):
     """
     Execute utility (pg_ctl, pg_dump etc).

@@ -59,6 +61,23 @@ def execute_utility(args, logfile=None):
         stdout of executed utility.
     """

+    if hostname != 'localhost':
+        conn = Connection(
+            hostname,
+            connect_kwargs={
+                "key_filename": f"{ssh_key}",
+            },
+        )
+
+        # TODO skip remote ssh run if we are on the localhost.
+        # result = conn.run('hostname', hide=True)
+        # add logger
+
+        cmd = ' '.join(args)
+        result = conn.run(cmd, hide=True)
+
+        return result
+
     # run utility
     if os.name == 'nt':
         # using output to a temporary file in Windows
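
An end-to-end sketch of the remote mode this patch wires together; the address, host name, and key path are placeholders:

    import testgres

    node = testgres.get_new_node(
        host='10.128.0.10',                 # placeholder address
        hostname='pg-remote',               # anything but 'localhost' selects SSH mode
        ssh_key='/home/dev/.ssh/id_rsa')
    node.init().start()
    print(node.safe_psql('postgres', 'SELECT version()'))
    node.stop()
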
From 4f38bd505ee99d09f40478a5cd30a1a685481c5b Mon Sep 17 00:00:00 2001
From: "v.shepard"
Date: Wed, 3 May 2023 11:50:35 +0200
Subject: [PATCH 06/23] PBCKP-152 multihost

---
 testgres/cache.py      |  2 +-
 testgres/config.py     | 10 ++++------
 testgres/connection.py |  3 +--
 testgres/os_ops.py     | 31 ++++++++++++-------------------
 4 files changed, 18 insertions(+), 28 deletions(-)

diff --git a/testgres/cache.py b/testgres/cache.py
index cd40e72f..0ca4d707 100644
--- a/testgres/cache.py
+++ b/testgres/cache.py
@@ -39,7 +39,7 @@ def call_initdb(initdb_dir, log=None):
         call_initdb(data_dir, logfile)
     else:
         # Fetch cached initdb dir
-        cached_data_dir = testgres_config.cached_initdb_dir
+        cached_data_dir = testgres_config.cached_initdb_dir()

         # Initialize cached initdb

diff --git a/testgres/config.py b/testgres/config.py
index b32536ba..2e4a3aaa 100644
--- a/testgres/config.py
+++ b/testgres/config.py
@@ -55,12 +55,6 @@ def cached_initdb_dir(self, value):
         self._cached_initdb_dir = value

-        # NOTE: assign initial cached dir for initdb
-        if self.os_ops:
-            testgres_config.cached_initdb_dir = self.os_ops.mkdtemp(prefix=TMP_CACHE)
-        else:
-            testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE)
-
         if value:
             cached_initdb_dirs.add(value)
         return testgres_config.cached_initdb_dir
@@ -208,3 +202,7 @@ def configure_testgres(**options):
     """

     testgres_config.update(options)
+
+
+# NOTE: assign initial cached dir for initdb
+testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE)
diff --git a/testgres/connection.py b/testgres/connection.py
index f85f56be..cc3dbdfe 100644
--- a/testgres/connection.py
+++ b/testgres/connection.py
@@ -42,8 +42,7 @@ def __init__(self,

         self._node = node

-        self.os_ops = OsOperations(node.host, node.hostname, node.ssh_key, node.username)
-        self._connection = self.os_ops.db_connect(dbname=dbname,
+        self._connection = node.os_ops.db_connect(dbname=dbname,
                                                   user=username,
                                                   password=password,
                                                   host=node.host,
                                                   port=node.port)
diff --git a/testgres/os_ops.py b/testgres/os_ops.py
index 290fee77..e87dcc88 100644
--- a/testgres/os_ops.py
+++ b/testgres/os_ops.py
@@ -262,32 +262,25 @@ def get_user(self):
             user = getpass.getuser()
         return user

-    @contextmanager
     def db_connect(self, dbname, user, password=None, host='localhost', port=5432):
         if self.remote:
-            with self.ssh_connect() as ssh:
-                # Set up a local port forwarding on a random port
-                local_port = ssh.forward_remote_port(host, port)
-                conn = pglib.connect(
-                    host=host,
-                    port=local_port,
-                    dbname=dbname,
-                    user=user,
-                    password=password,
-                )
-                try:
-                    yield conn
-                finally:
-                    conn.close()
-                    ssh.close_forwarded_tcp(local_port)
+            local_port = self.ssh.forward_remote_port(host, port)
+            conn = pglib.connect(
+                host=host,
+                port=local_port,
+                dbname=dbname,
+                user=user,
+                password=password,
+            )
         else:
-            with pglib.connect(
+            conn = pglib.connect(
                 host=host,
                 port=port,
                 dbname=dbname,
                 user=user,
                 password=password,
-            ) as conn:
-                yield conn
+            )
+        return conn
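
Since db_connect() is no longer a context manager, the caller now owns the connection and must close it; a minimal sketch with a placeholder user name:

    conn = node.os_ops.db_connect(dbname='postgres', user='dev')  # 'dev' is a placeholder
    try:
        cur = conn.cursor()
        cur.execute('SELECT 1')
        print(cur.fetchall())
    finally:
        conn.close()
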
""" - os_ops = OsOperations(hostname=hostname, ssh_key=ssh_key) def call_initdb(initdb_dir, log=None): try: _params = [get_bin_path("initdb"), "-D", initdb_dir, "-N"] - execute_utility(_params + (params or []), log, hostname=hostname, ssh_key=ssh_key) + execute_utility(_params + (params or []), log, os_ops) except ExecUtilException as e: raise_from(InitNodeException("Failed to run initdb"), e) @@ -42,7 +42,7 @@ def call_initdb(initdb_dir, log=None): # Initialize cached initdb if not os_ops.path_exists(cached_data_dir) or \ - not os_ops.listdir(cached_data_dir): + not os_ops.listdir(cached_data_dir): call_initdb(cached_data_dir) try: @@ -60,7 +60,7 @@ def call_initdb(initdb_dir, log=None): # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] - execute_utility(_params, logfile, hostname=hostname, ssh_key=ssh_key) + execute_utility(_params, logfile, os_ops) except ExecUtilException as e: msg = "Failed to reset WAL for system id" diff --git a/testgres/config.py b/testgres/config.py index 2e4a3aaa..1be76fbe 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -6,8 +6,8 @@ from contextlib import contextmanager from shutil import rmtree -from tempfile import mkdtemp +from .op_ops.local_ops import LocalOperations from .consts import TMP_CACHE @@ -137,12 +137,9 @@ def copy(self): @atexit.register -def _rm_cached_initdb_dirs(os_ops=None): +def _rm_cached_initdb_dirs(os_ops=LocalOperations()): for d in cached_initdb_dirs: - if os_ops: - os_ops.rmtree(d, ignore_errors=True) - else: - rmtree(d, ignore_errors=True) + os_ops.rmdirs(d, ignore_errors=True) def push_config(**options): @@ -205,4 +202,4 @@ def configure_testgres(**options): # NOTE: assign initial cached dir for initdb -testgres_config.cached_initdb_dir = mkdtemp(prefix=TMP_CACHE) +testgres_config.cached_initdb_dir = testgres_config.os_ops.mkdtemp(prefix=TMP_CACHE) diff --git a/testgres/connection.py b/testgres/connection.py index 763c2490..f3d01ebe 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -37,7 +37,7 @@ def __init__(self, # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() + username = username or default_username(node.os_ops) self._node = node diff --git a/testgres/defaults.py b/testgres/defaults.py index 539183ae..cac788b8 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -1,9 +1,10 @@ import datetime -import getpass import os import struct import uuid +from .op_ops.local_ops import LocalOperations + def default_dbname(): """ @@ -13,15 +14,11 @@ def default_dbname(): return 'postgres' -def default_username(os_ops=None): +def default_username(os_ops=LocalOperations()): """ Return default username (current user). """ - if os_ops: - user = os_ops.get_user() - else: - user = getpass.getuser() - return user + return os_ops.get_user() def generate_app_name(): @@ -32,7 +29,7 @@ def generate_app_name(): return 'testgres-{}'.format(str(uuid.uuid4())) -def generate_system_id(os_ops=None): +def generate_system_id(os_ops=LocalOperations()): """ Generate a new 64-bit unique system identifier for node. 
""" @@ -47,10 +44,7 @@ def generate_system_id(os_ops=None): system_id = 0 system_id |= (secs << 32) system_id |= (usecs << 12) - if os_ops: - system_id |= (os_ops.get_pid() & 0xFFF) - else: - system_id |= (os.getpid() & 0xFFF) + system_id |= (os_ops.get_pid() & 0xFFF) # pack ULL in native byte order return struct.pack('=Q', system_id) diff --git a/testgres/node.py b/testgres/node.py index 29ba2cf3..f06e5cc9 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -12,6 +12,9 @@ import subprocess import time +from op_ops.local_ops import LocalOperations +from op_ops.os_ops import OsOperations +from op_ops.remote_ops import RemoteOperations try: from collections.abc import Iterable @@ -101,7 +104,6 @@ clean_on_error from .backup import NodeBackup -from .os_ops import OsOperations InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError @@ -156,7 +158,10 @@ def __init__(self, name=None, port=None, base_dir=None, self.host = host self.hostname = hostname self.ssh_key = ssh_key - self.os_ops = OsOperations(host, hostname, ssh_key, username=username) + if hostname == 'localhost' or host == '127.0.0.1': + self.os_ops = LocalOperations(username=username) + else: + self.os_ops = RemoteOperations(host, hostname, ssh_key) # defaults for __exit__() self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit @@ -201,8 +206,9 @@ def pid(self): if self.status(): pid_file = os.path.join(self.data_dir, PG_PID_FILE) - with io.open(pid_file) as f: - return int(f.readline()) + lines = self.os_ops.readlines(pid_file, num_lines=1) + pid = int(lines[0]) if lines else None + return pid # for clarity return 0 @@ -280,11 +286,11 @@ def master(self): @property def base_dir(self): if not self._base_dir: - self._base_dir = mkdtemp(prefix=TMP_NODE) + self._base_dir = self.os_ops.mkdtemp(prefix=TMP_NODE) # NOTE: it's safe to create a new dir - if not os.path.exists(self._base_dir): - os.makedirs(self._base_dir) + if not self.os_ops.exists(self._base_dir): + self.os_ops.makedirs(self._base_dir) return self._base_dir @@ -293,8 +299,8 @@ def logs_dir(self): path = os.path.join(self.base_dir, LOGS_DIR) # NOTE: it's safe to create a new dir - if not os.path.exists(path): - os.makedirs(path) + if not self.os_ops.exists(path): + self.os_ops.makedirs(path) return path @@ -371,9 +377,7 @@ def _create_recovery_conf(self, username, slot=None): # Since 12 recovery.conf had disappeared if self.version >= PgVer('12'): signal_name = os.path.join(self.data_dir, "standby.signal") - # cross-python touch(). It is vulnerable to races, but who cares? 
-            with open(signal_name, 'a'):
-                os.utime(signal_name, None)
+            self.os_ops.touch(signal_name)
         else:
             line += "standby_mode=on\n"

@@ -431,19 +435,13 @@ def _collect_special_files(self):

         for f, num_lines in files:
             # skip missing files
-            if not os.path.exists(f):
+            if not self.os_ops.path_exists(f):
                 continue

-            with io.open(f, "rb") as _f:
-                if num_lines > 0:
-                    # take last N lines of file
-                    lines = b''.join(file_tail(_f, num_lines)).decode('utf-8')
-                else:
-                    # read whole file
-                    lines = _f.read().decode('utf-8')
+            lines = b''.join(self.os_ops.readlines(f, num_lines, encoding='utf-8'))

-                # fill list
-                result.append((f, lines))
+            # fill list
+            result.append((f, lines))

         return result

@@ -465,8 +463,7 @@ def init(self, initdb_params=None, **kwargs):
         cached_initdb(
             data_dir=self.data_dir,
             logfile=self.utils_log_file,
-            hostname=self.hostname,
-            ssh_key=self.ssh_key,
+            os_ops=self.os_ops,
             params=initdb_params)

         # initialize default config files
@@ -498,47 +495,44 @@ def default_conf(self,
         hba_conf = os.path.join(self.data_dir, HBA_CONF_FILE)

         # filter lines in hba file
-        with io.open(hba_conf, "r+") as conf:
-            # get rid of comments and blank lines
-            lines = [
-                s for s in conf.readlines()
-                if len(s.strip()) > 0 and not s.startswith('#')
-            ]
-
-            # write filtered lines
-            conf.seek(0)
-            conf.truncate()
-            conf.writelines(lines)
-
-            # replication-related settings
-            if allow_streaming:
-                # get auth method for host or local users
-                def get_auth_method(t):
-                    return next((s.split()[-1]
-                                 for s in lines if s.startswith(t)), 'trust')
-
-                # get auth methods
-                auth_local = get_auth_method('local')
-                auth_host = get_auth_method('host')
-
-                new_lines = [
-                    u"local\treplication\tall\t\t\t{}\n".format(auth_local),
-                    u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host),
-
-                    u"host\treplication\tall\t0.0.0.0/0\t{}\n".format(auth_host),
-                    u"host\tall\tall\t0.0.0.0/0\t{}\n".format(auth_host),
-
-                    u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host)
-                ]  # yapf: disable
-
-                # write missing lines
-                for line in new_lines:
-                    if line not in lines:
-                        conf.write(line)
+        # get rid of comments and blank lines
+        hba_conf_file = self.os_ops.readlines(hba_conf)
+        lines = [
+            s for s in hba_conf_file
+            if len(s.strip()) > 0 and not s.startswith('#')
+        ]
+
+        # write filtered lines
+        self.os_ops.write(hba_conf, lines, truncate=True)
+
+        # replication-related settings
+        if allow_streaming:
+            # get auth method for host or local users
+            def get_auth_method(t):
+                return next((s.split()[-1]
+                             for s in lines if s.startswith(t)), 'trust')
+
+            # get auth methods
+            auth_local = get_auth_method('local')
+            auth_host = get_auth_method('host')
+
+            new_lines = [
+                u"local\treplication\tall\t\t\t{}\n".format(auth_local),
+                u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host),
+
+                u"host\treplication\tall\t0.0.0.0/0\t{}\n".format(auth_host),
+                u"host\tall\tall\t0.0.0.0/0\t{}\n".format(auth_host),
+
+                u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host)
+            ]  # yapf: disable
+
+            # write missing lines
+            for line in new_lines:
+                if line not in lines:
+                    self.os_ops.write(hba_conf, line)

         # overwrite config file
-        with io.open(postgres_conf, "w") as conf:
-            conf.truncate()
+        self.os_ops.write(postgres_conf, '', truncate=True)

         self.append_conf(fsync=fsync,
                          max_worker_processes=MAX_WORKER_PROCESSES,
@@ -613,10 +607,10 @@ def append_conf(self, line='', filename=PG_CONF_FILE, **kwargs):
                 lines.append('{} = {}'.format(option, value))

         config_name = os.path.join(self.data_dir, filename)
-        with io.open(config_name, 'a') as conf:
- for line in lines: - conf.write(text_type(line)) - conf.write(text_type('\n')) + conf_text = '' + for line in lines: + conf_text += text_type(line) + '\n' + self.os_ops.write(config_name, conf_text) return self @@ -971,10 +965,7 @@ def psql(self, psql_params.append(dbname) # start psql process - process = subprocess.Popen(psql_params, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + process = self.os_ops.exec_command(psql_params) # wait until it finishes and get stdout and stderr out, err = process.communicate(input=input) @@ -1351,7 +1342,7 @@ def pgbench(self, # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() + username = username or default_username(self.os_ops) _params = [ get_bin_path("pgbench"), @@ -1363,7 +1354,7 @@ def pgbench(self, # should be the last one _params.append(dbname) - proc = subprocess.Popen(_params, stdout=stdout, stderr=stderr) + proc = self.os_ops.exec_command(_params, wait_exit=True) return proc @@ -1403,7 +1394,7 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): # Set default arguments dbname = dbname or default_dbname() - username = username or default_username() + username = username or default_username(os_ops=self.os_ops) _params = [ get_bin_path("pgbench"), @@ -1534,10 +1525,8 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): Defaults to an empty set. """ # parse postgresql.auto.conf - path = os.path.join(self.data_dir, config) - - with open(path, 'r') as f: - raw_content = f.read() + auto_conf_file = os.path.join(self.data_dir, config) + raw_content = self.os_ops.read(auto_conf_file) current_options = {} current_directives = [] @@ -1577,22 +1566,22 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): for directive in current_directives: auto_conf += directive + "\n" - with open(path, 'wt') as f: - f.write(auto_conf) + self.os_ops.write(auto_conf_file, auto_conf) class NodeApp: - def __init__(self, test_path, nodes_to_cleanup): + def __init__(self, test_path, nodes_to_cleanup, os_ops=LocalOperations()): self.test_path = test_path self.nodes_to_cleanup = nodes_to_cleanup + self.os_ops = os_ops def make_empty( self, base_dir=None): real_base_dir = os.path.join(self.test_path, base_dir) - shutil.rmtree(real_base_dir, ignore_errors=True) - os.makedirs(real_base_dir) + self.os_ops.rmdirs(real_base_dir, ignore_errors=True) + self.os_ops.makedirs(real_base_dir) node = PostgresNode(base_dir=real_base_dir) node.should_rm_dirs = True @@ -1615,27 +1604,24 @@ def make_simple( initdb_params=initdb_params, allow_streaming=set_replication) # set major version - with open(os.path.join(node.data_dir, 'PG_VERSION')) as f: - node.major_version_str = str(f.read().rstrip()) - node.major_version = float(node.major_version_str) - - # Sane default parameters - options = {} - options['max_connections'] = 100 - options['shared_buffers'] = '10MB' - options['fsync'] = 'off' - - options['wal_level'] = 'logical' - options['hot_standby'] = 'off' - - options['log_line_prefix'] = '%t [%p]: [%l-1] ' - options['log_statement'] = 'none' - options['log_duration'] = 'on' - options['log_min_duration_statement'] = 0 - options['log_connections'] = 'on' - options['log_disconnections'] = 'on' - options['restart_after_crash'] = 'off' - options['autovacuum'] = 'off' + pg_version_file = self.os_ops.read(os.path.join(node.data_dir, 'PG_VERSION')) + node.major_version_str = str(pg_version_file.rstrip()) + node.major_version = 
float(node.major_version_str) + + # Set default parameters + options = {'max_connections': 100, + 'shared_buffers': '10MB', + 'fsync': 'off', + 'wal_level': 'logical', + 'hot_standby': 'off', + 'log_line_prefix': '%t [%p]: [%l-1] ', + 'log_statement': 'none', + 'log_duration': 'on', + 'log_min_duration_statement': 0, + 'log_connections': 'on', + 'log_disconnections': 'on', + 'restart_after_crash': 'off', + 'autovacuum': 'off'} # Allow replication in pg_hba.conf if set_replication: diff --git a/testgres/op_ops/local_ops.py b/testgres/op_ops/local_ops.py new file mode 100644 index 00000000..42a3b4b7 --- /dev/null +++ b/testgres/op_ops/local_ops.py @@ -0,0 +1,224 @@ +import getpass +import os +import shutil +import subprocess +import tempfile +from shutil import rmtree + +from testgres.logger import log + +from .os_ops import OsOperations +from .os_ops import pglib + +CMD_TIMEOUT_SEC = 60 + + +class LocalOperations(OsOperations): + + def __init__(self, username=None): + super().__init__() + self.username = username or self.get_user() + + # Command execution + def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): + if isinstance(cmd, list): + cmd = ' '.join(cmd) + log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") + # Source global profile file + execute command + try: + process = subprocess.run(cmd, shell=True, text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=CMD_TIMEOUT_SEC) + exit_status = process.returncode + result = process.stdout + error = process.stderr + + if expect_error: + raise Exception(result, error) + if exit_status != 0 or 'error' in error.lower(): + log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") + exit(1) + + if verbose: + return exit_status, result, error + else: + return result + + except Exception as e: + log.error(f"Unexpected error while executing command `{cmd}`: {e}") + return None + + # Environment setup + def environ(self, var_name): + cmd = f"echo ${var_name}" + return self.exec_command(cmd).strip() + + def find_executable(self, executable): + search_paths = self.environ('PATH') + if not search_paths: + return None + + search_paths = search_paths.split(self.pathsep) + for path in search_paths: + remote_file = os.path.join(path, executable) + if self.isfile(remote_file): + return remote_file + + return None + + def is_executable(self, file): + # Check if the file is executable + return os.access(file, os.X_OK) + + def add_to_path(self, new_path): + pathsep = self.pathsep + # Check if the directory is already in PATH + path = self.environ('PATH') + if new_path not in path.split(pathsep): + if self.remote: + self.exec_command(f"export PATH={new_path}{pathsep}{path}") + else: + os.environ['PATH'] = f"{new_path}{pathsep}{path}" + return pathsep + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + os.environ[var_name] = var_val + + # Get environment variables + def get_user(self): + return getpass.getuser() + + def get_name(self): + cmd = 'python3 -c "import os; print(os.name)"' + return self.exec_command(cmd).strip() + + # Work with dirs + def makedirs(self, path, remove_existing=False): + if remove_existing and os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) + + def rmdirs(self, path, ignore_errors=True): + return rmtree(path, ignore_errors=ignore_errors) + + def listdir(self, path): + return os.listdir(path) + + def path_exists(self, path): + return os.path.exists(path) + + @property + def 
pathsep(self): + os_name = self.get_name() + if os_name == 'posix': + pathsep = ':' + elif os_name == 'nt': + pathsep = ';' + else: + raise Exception(f"Unsupported operating system: {os_name}") + return pathsep + + def mkdtemp(self, prefix=None): + return tempfile.mkdtemp(prefix=prefix) + + def copytree(self, src, dst): + return shutil.copytree(src, dst) + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + """ + Write data to a file locally + Args: + filename: The file path where the data will be written. + data: The data to be written to the file. + truncate: If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + binary: If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + read_and_write: If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) + """ + mode = 'wb' if binary else 'w' + if not truncate: + mode = 'a' + mode + if read_and_write: + mode = 'r+' + mode + + with open(filename, mode) as file: + if isinstance(data, list): + file.writelines(data) + else: + file.write(data) + + def touch(self, filename): + """ + Create a new file or update the access and modification times of an existing file. + Args: + filename (str): The name of the file to touch. + + This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. + """ + # cross-python touch(). It is vulnerable to races, but who cares? + with open(filename, 'a'): + os.utime(filename, None) + + def read(self, filename): + with open(filename, 'r') as file: + return file.read() + + def readlines(self, filename, num_lines=0, encoding=None): + """ + Read lines from a local file. + If num_lines is greater than 0, only the last num_lines lines will be read. 
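+
+        Example (illustrative sketch; sample.txt is a hypothetical file
+        holding "a\n", "b\n", "c\n"):
+            ops = LocalOperations()
+            ops.readlines("sample.txt")               # ["a\n", "b\n", "c\n"]
+            ops.readlines("sample.txt", num_lines=2)  # last two lines: ["b\n", "c\n"]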
+ """ + assert num_lines >= 0 + + if num_lines == 0: + with open(filename, 'r', encoding=encoding) as file: + return file.readlines() + + else: + bufsize = 8192 + buffers = 1 + + with open(filename, 'r', encoding=encoding) as file: + file.seek(0, os.SEEK_END) + end_pos = file.tell() + + while True: + offset = max(0, end_pos - bufsize * buffers) + file.seek(offset, os.SEEK_SET) + pos = file.tell() + lines = file.readlines() + cur_lines = len(lines) + + if cur_lines >= num_lines or pos == 0: + return lines[-num_lines:] + + buffers = int(buffers * max(2, int(num_lines / max(cur_lines, 1)))) # Adjust buffer size + + def isfile(self, remote_file): + return os.path.isfile(remote_file) + + # Processes control + def kill(self, pid, signal): + # Kill the process + cmd = f'kill -{signal} {pid}' + return self.exec_command(cmd) + + def get_pid(self): + # Get current process id + return os.getpid() + + # Database control + def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + conn = pglib.connect( + host=host, + port=port, + database=dbname, + user=user, + password=password, + ) + return conn diff --git a/testgres/op_ops/os_ops.py b/testgres/op_ops/os_ops.py new file mode 100644 index 00000000..89de2640 --- /dev/null +++ b/testgres/op_ops/os_ops.py @@ -0,0 +1,99 @@ +try: + import psycopg2 as pglib +except ImportError: + try: + import pg8000 as pglib + except ImportError: + raise ImportError("You must have psycopg2 or pg8000 modules installed") + +from testgres.defaults import default_username + + +class OsOperations: + + def __init__(self, username=None): + self.hostname = 'localhost' + self.remote = False + self.ssh = None + self.username = username + + # Command execution + def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): + raise NotImplementedError() + + # Environment setup + def environ(self, var_name): + raise NotImplementedError() + + def find_executable(self, executable): + raise NotImplementedError() + + def is_executable(self, file): + # Check if the file is executable + raise NotImplementedError() + + def add_to_path(self, new_path): + raise NotImplementedError() + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + raise NotImplementedError() + + # Get environment variables + def get_user(self): + raise NotImplementedError() + + def get_name(self): + raise NotImplementedError() + + # Work with dirs + def makedirs(self, path, remove_existing=False): + raise NotImplementedError() + + def rmdirs(self, path, ignore_errors=True): + raise NotImplementedError() + + def listdir(self, path): + raise NotImplementedError() + + def path_exists(self, path): + raise NotImplementedError() + + @property + def pathsep(self): + raise NotImplementedError() + + def mkdtemp(self, prefix=None): + raise NotImplementedError() + + def copytree(self, src, dst): + raise NotImplementedError() + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + raise NotImplementedError() + + def touch(self, filename): + raise NotImplementedError() + + def read(self, filename): + raise NotImplementedError() + + def readlines(self, filename): + raise NotImplementedError() + + def isfile(self, remote_file): + raise NotImplementedError() + + # Processes control + def kill(self, pid, signal): + # Kill the process + raise NotImplementedError() + + def get_pid(self): + # Get current process id + raise NotImplementedError() + + # Database control + def db_connect(self, dbname, user, 
password=None, host='localhost', port=5432): + raise NotImplementedError() diff --git a/testgres/op_ops/remote_ops.py b/testgres/op_ops/remote_ops.py new file mode 100644 index 00000000..d5faab4e --- /dev/null +++ b/testgres/op_ops/remote_ops.py @@ -0,0 +1,259 @@ +import os +import tempfile +from contextlib import contextmanager + +from testgres.logger import log + +from .os_ops import OsOperations +from .os_ops import pglib + +import paramiko + + +class RemoteOperations(OsOperations): + """ + This class specifically supports work with Linux systems. It utilizes the SSH + for making connections and performing various file and directory operations, command executions, + environment setup and management, process control, and database connections. + It uses the Paramiko library for SSH connections and operations. + + Some methods are designed to work with specific Linux shell commands, and thus may not work as expected + on other non-Linux systems. + + Attributes: + - hostname (str): The remote system's hostname. Default 'localhost'. + - host (str): The remote system's IP address. Default '127.0.0.1'. + - ssh_key (str): Path to the SSH private key for authentication. + - username (str): Username for the remote system. + - ssh (paramiko.SSHClient): SSH connection to the remote system. + """ + + def __init__(self, hostname='localhost', host='127.0.0.1', ssh_key=None, username=None): + super().__init__(username) + self.hostname = hostname + self.host = host + self.ssh_key = ssh_key + self.remote = True + self.ssh = self.connect() + self.username = username or self.get_user() + + def __del__(self): + if self.ssh: + self.ssh.close() + + @contextmanager + def ssh_connect(self): + if not self.remote: + yield None + else: + with open(self.ssh_key, 'r') as f: + key_data = f.read() + if 'BEGIN OPENSSH PRIVATE KEY' in key_data: + key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) + else: + key = paramiko.RSAKey.from_private_key_file(self.ssh_key) + + with paramiko.SSHClient() as ssh: + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.host, username=self.username, pkey=key) + yield ssh + + def connect(self): + with self.ssh_connect() as ssh: + return ssh + + # Command execution + def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding='utf-8'): + if isinstance(cmd, list): + cmd = ' '.join(cmd) + log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") + # Source global profile file + execute command + try: + cmd = f"source /etc/profile.d/custom.sh; {cmd}" + with self.ssh_connect() as ssh: + stdin, stdout, stderr = ssh.exec_command(cmd) + exit_status = 0 + if wait_exit: + exit_status = stdout.channel.recv_exit_status() + result = stdout.read().decode(encoding) + error = stderr.read().decode(encoding) + + if expect_error: + raise Exception(result, error) + if exit_status != 0 or 'error' in error.lower(): + log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") + exit(1) + + if verbose: + return exit_status, result, error + else: + return result + + except Exception as e: + log.error(f"Unexpected error while executing command `{cmd}`: {e}") + return None + + # Environment setup + def environ(self, var_name): + cmd = f"echo ${var_name}" + return self.exec_command(cmd).strip() + + def find_executable(self, executable): + search_paths = self.environ('PATH') + if not search_paths: + return None + + search_paths = search_paths.split(self.pathsep) + for path in search_paths: + remote_file = 
os.path.join(path, executable) + if self.isfile(remote_file): + return remote_file + + return None + + def is_executable(self, file): + # Check if the file is executable + return self.exec_command(f"test -x {file} && echo OK") == 'OK\n' + + def add_to_path(self, new_path): + pathsep = self.pathsep + # Check if the directory is already in PATH + path = self.environ('PATH') + if new_path not in path.split(pathsep): + if self.remote: + self.exec_command(f"export PATH={new_path}{pathsep}{path}") + else: + os.environ['PATH'] = f"{new_path}{pathsep}{path}" + return pathsep + + def set_env(self, var_name, var_val): + # Check if the directory is already in PATH + return self.exec_command(f"export {var_name}={var_val}") + + # Get environment variables + def get_user(self): + return self.exec_command(f"echo $USER") + + def get_name(self): + cmd = 'python3 -c "import os; print(os.name)"' + return self.exec_command(cmd).strip() + + # Work with dirs + def makedirs(self, path, remove_existing=False): + if remove_existing: + cmd = f'rm -rf {path} && mkdir -p {path}' + else: + cmd = f'mkdir -p {path}' + return self.exec_command(cmd) + + def rmdirs(self, path, ignore_errors=True): + cmd = f'rm -rf {path}' + return self.exec_command(cmd) + + def listdir(self, path): + result = self.exec_command(f'ls {path}') + return result.splitlines() + + def path_exists(self, path): + result = self.exec_command(f'test -e {path}; echo $?') + return int(result.strip()) == 0 + + @property + def pathsep(self): + os_name = self.get_name() + if os_name == 'posix': + pathsep = ':' + elif os_name == 'nt': + pathsep = ';' + else: + raise Exception(f"Unsupported operating system: {os_name}") + return pathsep + + def mkdtemp(self, prefix=None): + temp_dir = self.exec_command(f'mkdtemp -d {prefix}') + return temp_dir.strip() + + def copytree(self, src, dst): + return self.exec_command(f'cp -r {src} {dst}') + + # Work with files + def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + """ + Write data to a file on a remote host + Args: + filename: The file path where the data will be written. + data: The data to be written to the file. + truncate: If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + binary: If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + read_and_write: If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) + """ + mode = 'wb' if binary else 'w' + if not truncate: + mode = 'a' + mode + if read_and_write: + mode = 'r+' + mode + + with tempfile.NamedTemporaryFile(mode=mode) as tmp_file: + if isinstance(data, list): + tmp_file.writelines(data) + else: + tmp_file.write(data) + tmp_file.flush() + + sftp = self.ssh.open_sftp() + sftp.put(tmp_file.name, filename) + sftp.close() + + def touch(self, filename): + """ + Create a new file or update the access and modification times of an existing file on the remote server. + + Args: + filename (str): The name of the file to touch. + + This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. 
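+
+        Example (illustrative sketch; the path is hypothetical):
+            ops.touch('/tmp/testgres.lock')  # runs `touch /tmp/testgres.lock` over SSH
+            assert ops.isfile('/tmp/testgres.lock')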
+ """ + self.exec_command(f'touch {filename}') + + def read(self, filename, encoding='utf-8'): + cmd = f'cat {filename}' + return self.exec_command(cmd, encoding=encoding) + + def readlines(self, filename, num_lines=0, encoding=None): + encoding = encoding or 'utf-8' + if num_lines > 0: + cmd = f'tail -n {num_lines} {filename}' + lines = self.exec_command(cmd, encoding) + else: + lines = self.read(filename, encoding=encoding).splitlines() + return lines + + def isfile(self, remote_file): + stdout = self.exec_command(f'test -f {remote_file}; echo $?') + result = int(stdout.strip()) + return result == 0 + + # Processes control + def kill(self, pid, signal): + # Kill the process + cmd = f'kill -{signal} {pid}' + return self.exec_command(cmd) + + def get_pid(self): + # Get current process id + return self.exec_command(f"echo $$") + + # Database control + def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + local_port = self.ssh.forward_remote_port(host, port) + conn = pglib.connect( + host=host, + port=local_port, + database=dbname, + user=user, + password=password, + ) + return conn diff --git a/testgres/os_ops.py b/testgres/os_ops.py deleted file mode 100644 index 0be8c2a7..00000000 --- a/testgres/os_ops.py +++ /dev/null @@ -1,285 +0,0 @@ -import getpass -import os -import shutil -import subprocess -import tempfile -from contextlib import contextmanager -from shutil import rmtree - -try: - import psycopg2 as pglib -except ImportError: - try: - import pg8000 as pglib - except ImportError: - raise ImportError("You must have psycopg2 or pg8000 modules installed") - -from testgres.defaults import default_username -from testgres.logger import log - -import paramiko - - -class OsOperations: - - def __init__(self, host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username()): - self.host = host - self.ssh_key = ssh_key - self.username = username - self.remote = not (self.host == '127.0.0.1' and hostname == 'localhost') - self.ssh = None - - if self.remote: - self.ssh = self.connect() - - def __del__(self): - if self.ssh: - self.ssh.close() - - @contextmanager - def ssh_connect(self): - if not self.remote: - yield None - else: - with open(self.ssh_key, 'r') as f: - key_data = f.read() - if 'BEGIN OPENSSH PRIVATE KEY' in key_data: - key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) - else: - key = paramiko.RSAKey.from_private_key_file(self.ssh_key) - - with paramiko.SSHClient() as ssh: - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.host, username=self.username, pkey=key) - yield ssh - - def connect(self): - with self.ssh_connect() as ssh: - return ssh - - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): - if isinstance(cmd, list): - cmd = ' '.join(cmd) - log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") - # Source global profile file + execute command - try: - if self.remote: - cmd = f"source /etc/profile.d/custom.sh; {cmd}" - with self.ssh_connect() as ssh: - stdin, stdout, stderr = ssh.exec_command(cmd) - exit_status = 0 - if wait_exit: - exit_status = stdout.channel.recv_exit_status() - result = stdout.read().decode('utf-8') - error = stderr.read().decode('utf-8') - else: - process = subprocess.run(cmd, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - timeout=60) - exit_status = process.returncode - result = process.stdout - error = process.stderr - - if expect_error: - raise Exception(result, error) - if exit_status != 0 or 'error' in 
error.lower(): - log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") - exit(1) - - if verbose: - return exit_status, result, error - else: - return result - - except Exception as e: - log.error(f"Unexpected error while executing command `{cmd}`: {e}") - return None - - def makedirs(self, path, remove_existing=False): - if remove_existing: - cmd = f'rm -rf {path} && mkdir -p {path}' - else: - cmd = f'mkdir -p {path}' - self.exec_command(cmd) - - def rmdirs(self, path, ignore_errors=True): - if self.remote: - cmd = f'rm -rf {path}' - self.exec_command(cmd) - else: - rmtree(path, ignore_errors=ignore_errors) - - def mkdtemp(self, prefix=None): - if self.remote: - temp_dir = self.exec_command(f'mkdtemp -d {prefix}') - return temp_dir.strip() - else: - return tempfile.mkdtemp(prefix=prefix) - - def path_exists(self, path): - if self.remote: - result = self.exec_command(f'test -e {path}; echo $?') - return int(result.strip()) == 0 - else: - return os.path.exists(path) - - def copytree(self, src, dst): - if self.remote: - self.exec_command(f'cp -r {src} {dst}') - else: - shutil.copytree(src, dst) - - def listdir(self, path): - if self.remote: - result = self.exec_command(f'ls {path}') - return result.splitlines() - else: - return os.listdir(path) - - def write(self, filename, data, truncate=False, binary=False, read_and_write=False): - """ - Write data to a file, both locally and on a remote host. - - :param filename: The file path where the data will be written. - :param data: The data to be written to the file. - :param truncate: If True, the file will be truncated before writing ('w' or 'wb' option); - if False (default), data will be appended ('a' or 'ab' option). - :param binary: If True, the data will be written in binary mode ('wb' or 'ab' option); - if False (default), the data will be written in text mode ('w' or 'a' option). - :param read_and_write: If True, the file will be opened with read and write permissions ('r+' option); - if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option). 
- """ - mode = 'wb' if binary else 'w' - if not truncate: - mode = 'a' + mode - if read_and_write: - mode = 'r+' + mode - - if self.remote: - with tempfile.NamedTemporaryFile() as tmp_file: - tmp_file.write(data) - tmp_file.flush() - - sftp = self.ssh.open_sftp() - sftp.put(tmp_file.name, filename) - sftp.close() - else: - with open(filename, mode) as file: - file.write(data) - - def read(self, filename): - cmd = f'cat {filename}' - return self.exec_command(cmd) - - def readlines(self, filename): - return self.read(filename).splitlines() - - def get_name(self): - cmd = 'python3 -c "import os; print(os.name)"' - return self.exec_command(cmd).strip() - - def kill(self, pid, signal): - cmd = f'kill -{signal} {pid}' - self.exec_command(cmd) - - def environ(self, var_name): - cmd = f"echo ${var_name}" - return self.exec_command(cmd).strip() - - @property - def pathsep(self): - return ':' if self.get_name() == 'posix' else ';' - - def isfile(self, remote_file): - if self.remote: - stdout = self.exec_command(f'test -f {remote_file}; echo $?') - result = int(stdout.strip()) - return result == 0 - else: - return os.path.isfile(remote_file) - - def find_executable(self, executable): - search_paths = self.environ('PATH') - if not search_paths: - return None - - search_paths = search_paths.split(self.pathsep) - for path in search_paths: - remote_file = os.path.join(path, executable) - if self.isfile(remote_file): - return remote_file - - return None - - def is_executable(self, file): - # Check if the file is executable - if self.remote: - if not self.exec_command(f"test -x {file} && echo OK") == 'OK\n': - return False - else: - if not os.access(file, os.X_OK): - return False - return True - - def add_to_path(self, new_path): - os_name = self.get_name() - if os_name == 'posix': - dir_del = ':' - elif os_name == 'nt': - dir_del = ';' - else: - raise Exception(f"Unsupported operating system: {os_name}") - - # Check if the directory is already in PATH - path = self.environ('PATH') - if new_path not in path.split(dir_del): - if self.remote: - self.exec_command(f"export PATH={new_path}{dir_del}{path}") - else: - os.environ['PATH'] = f"{new_path}{dir_del}{path}" - return dir_del - - def set_env(self, var_name, var_val): - # Check if the directory is already in PATH - if self.remote: - self.exec_command(f"export {var_name}={var_val}") - else: - os.environ[var_name] = var_val - - def get_pid(self): - # Get current process id - if self.remote: - process_id = self.exec_command(f"echo $$") - else: - process_id = os.getpid() - return process_id - - def get_user(self): - # Get current user - if self.remote: - user = self.exec_command(f"echo $USER") - else: - user = getpass.getuser() - return user - - def db_connect(self, dbname, user, password=None, host='localhost', port=5432): - if self.remote: - local_port = self.ssh.forward_remote_port(host, port) - conn = pglib.connect( - host=host, - port=local_port, - dbname=dbname, - user=user, - password=password, - ) - else: - conn = pglib.connect( - host=host, - port=port, - dbname=dbname, - user=user, - password=password, - ) - return conn - - - diff --git a/testgres/utils.py b/testgres/utils.py index b27fb6b8..73ca6f1a 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -19,6 +19,8 @@ from six import iteritems from fabric import Connection +from .op_ops.local_ops import LocalOperations +from .op_ops.os_ops import OsOperations from .config import testgres_config from .exceptions import ExecUtilException @@ -52,7 +54,7 @@ def release_port(port): 
bound_ports.discard(port) -def execute_utility(args, logfile=None, hostname='localhost', ssh_key=None): +def execute_utility(args, logfile=None, os_ops: OsOperations = LocalOperations()): """ Execute utility (pg_ctl, pg_dump etc). @@ -64,11 +66,11 @@ def execute_utility(args, logfile=None, hostname='localhost', ssh_key=None): stdout of executed utility. """ - if hostname != 'localhost': + if os_ops.hostname != 'localhost': conn = Connection( - hostname, + os_ops.hostname, connect_kwargs={ - "key_filename": f"{ssh_key}", + "key_filename": f"{os_ops.ssh_key}", }, ) From ac77ef78f640b9b518817403c6841b25e8f46e9b Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 12 Jun 2023 00:31:28 +0200 Subject: [PATCH 08/23] PBCKP-152 use black for formatting --- testgres/cache.py | 7 +- testgres/config.py | 9 ++- testgres/defaults.py | 3 +- testgres/node.py | 42 +++++------- testgres/{op_ops => os_ops}/local_ops.py | 62 ++++++++++------- testgres/{op_ops => os_ops}/os_ops.py | 13 ++-- testgres/{op_ops => os_ops}/remote_ops.py | 81 +++++++++++++---------- testgres/utils.py | 20 +++--- 8 files changed, 127 insertions(+), 110 deletions(-) rename testgres/{op_ops => os_ops}/local_ops.py (80%) rename testgres/{op_ops => os_ops}/os_ops.py (90%) rename testgres/{op_ops => os_ops}/remote_ops.py (79%) diff --git a/testgres/cache.py b/testgres/cache.py index 4998e0d2..1df5a8ea 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -4,8 +4,8 @@ from six import raise_from -from .op_ops.local_ops import LocalOperations -from .op_ops.os_ops import OsOperations +from .os_ops.local_ops import LocalOperations +from .os_ops.os_ops import OsOperations from .config import testgres_config from .consts import XLOG_CONTROL_FILE @@ -25,6 +25,7 @@ def cached_initdb(data_dir, logfile=None, params=None, os_ops: OsOperations = Lo """ Perform initdb or use cached node files. 
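
    Example (illustrative sketch; the paths are hypothetical):
        cached_initdb(
            data_dir="/tmp/node1/data",
            logfile="/tmp/node1/logs/utils.log",
            os_ops=LocalOperations(),
        )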
""" + testgres_config.os_ops = os_ops def call_initdb(initdb_dir, log=None): try: @@ -60,7 +61,7 @@ def call_initdb(initdb_dir, log=None): # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] - execute_utility(_params, logfile, os_ops) + execute_utility(_params, logfile, os_ops=os_ops) except ExecUtilException as e: msg = "Failed to reset WAL for system id" diff --git a/testgres/config.py b/testgres/config.py index 1be76fbe..fd942664 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -5,9 +5,8 @@ import tempfile from contextlib import contextmanager -from shutil import rmtree -from .op_ops.local_ops import LocalOperations +from .os_ops.local_ops import LocalOperations from .consts import TMP_CACHE @@ -44,7 +43,7 @@ class GlobalConfig(object): _cached_initdb_dir = None """ underlying class attribute for cached_initdb_dir property """ - os_ops = None + os_ops = LocalOperations() """ OsOperation object that allows work on remote host """ @property def cached_initdb_dir(self): @@ -137,9 +136,9 @@ def copy(self): @atexit.register -def _rm_cached_initdb_dirs(os_ops=LocalOperations()): +def _rm_cached_initdb_dirs(): for d in cached_initdb_dirs: - os_ops.rmdirs(d, ignore_errors=True) + testgres_config.os_ops.rmdirs(d, ignore_errors=True) def push_config(**options): diff --git a/testgres/defaults.py b/testgres/defaults.py index cac788b8..5ffc08de 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -1,9 +1,8 @@ import datetime -import os import struct import uuid -from .op_ops.local_ops import LocalOperations +from .os_ops.local_ops import LocalOperations def default_dbname(): diff --git a/testgres/node.py b/testgres/node.py index f06e5cc9..6456e5a9 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -1,20 +1,16 @@ # coding: utf-8 -import io import os import random -import shutil import signal import threading from queue import Queue import psutil -import subprocess import time -from op_ops.local_ops import LocalOperations -from op_ops.os_ops import OsOperations -from op_ops.remote_ops import RemoteOperations +from .os_ops.local_ops import LocalOperations +from .os_ops.remote_ops import RemoteOperations try: from collections.abc import Iterable @@ -32,7 +28,6 @@ from shutil import rmtree from six import raise_from, iteritems, text_type -from tempfile import mkstemp, mkdtemp from .enums import \ NodeStatus, \ @@ -96,7 +91,6 @@ eprint, \ get_bin_path, \ get_pg_version, \ - file_tail, \ reserve_port, \ release_port, \ execute_utility, \ @@ -163,6 +157,7 @@ def __init__(self, name=None, port=None, base_dir=None, else: self.os_ops = RemoteOperations(host, hostname, ssh_key) + testgres_config.os_ops = self.os_ops # defaults for __exit__() self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit @@ -289,7 +284,7 @@ def base_dir(self): self._base_dir = self.os_ops.mkdtemp(prefix=TMP_NODE) # NOTE: it's safe to create a new dir - if not self.os_ops.exists(self._base_dir): + if not self.os_ops.path_exists(self._base_dir): self.os_ops.makedirs(self._base_dir) return self._base_dir @@ -299,7 +294,7 @@ def logs_dir(self): path = os.path.join(self.base_dir, LOGS_DIR) # NOTE: it's safe to create a new dir - if not self.os_ops.exists(path): + if not self.os_ops.path_exists(path): self.os_ops.makedirs(path) return path @@ -628,7 +623,7 @@ def status(self): "-D", self.data_dir, "status" ] # yapf: disable - execute_utility(_params, self.utils_log_file) + 
execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) return NodeStatus.Running except ExecUtilException as e: @@ -650,7 +645,7 @@ def get_control_data(self): _params += ["-D"] if self._pg_version >= PgVer('9.5') else [] _params += [self.data_dir] - data = execute_utility(_params, self.utils_log_file) + data = execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) out_dict = {} @@ -713,7 +708,7 @@ def start(self, params=[], wait=True): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) except ExecUtilException as e: msg = 'Cannot start node' files = self._collect_special_files() @@ -744,7 +739,7 @@ def stop(self, params=[], wait=True): "stop" ] + params # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) self._maybe_stop_logger() self.is_started = False @@ -786,7 +781,7 @@ def restart(self, params=[]): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) except ExecUtilException as e: msg = 'Cannot restart node' files = self._collect_special_files() @@ -813,7 +808,7 @@ def reload(self, params=[]): "reload" ] + params # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) return self @@ -835,7 +830,7 @@ def promote(self, dbname=None, username=None): "promote" ] # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) # for versions below 10 `promote` is asynchronous so we need to wait # until it actually becomes writable @@ -870,7 +865,7 @@ def pg_ctl(self, params): "-w" # wait ] + params # yapf: disable - return execute_utility(_params, self.utils_log_file) + return execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) def free_port(self): """ @@ -1035,10 +1030,9 @@ def dump(self, # Generate tmpfile or tmpdir def tmpfile(): if format == DumpFormat.Directory: - fname = mkdtemp(prefix=TMP_DUMP) + fname = self.os_ops.mkdtemp(prefix=TMP_DUMP) else: - fd, fname = mkstemp(prefix=TMP_DUMP) - os.close(fd) + fname = self.os_ops.mkstemp(prefix=TMP_DUMP) return fname # Set default arguments @@ -1056,7 +1050,7 @@ def tmpfile(): "-F", format.value ] # yapf: disable - execute_utility(_params, self.utils_log_file) + execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) return filename @@ -1085,7 +1079,7 @@ def restore(self, filename, dbname=None, username=None): # try pg_restore if dump is binary formate, and psql if not try: - execute_utility(_params, self.utils_log_name) + execute_utility(_params, self.utils_log_name, os_ops=self.os_ops) except ExecUtilException: self.psql(filename=filename, dbname=dbname, username=username) @@ -1417,7 +1411,7 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): # should be the last one _params.append(dbname) - return execute_utility(_params, self.utils_log_file) + return execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) def connect(self, dbname=None, diff --git a/testgres/op_ops/local_ops.py b/testgres/os_ops/local_ops.py similarity index 80% rename from testgres/op_ops/local_ops.py rename to testgres/os_ops/local_ops.py index 42a3b4b7..d6977c9f 100644 --- a/testgres/op_ops/local_ops.py +++ b/testgres/os_ops/local_ops.py @@ -14,7 +14,6 @@ class 
LocalOperations(OsOperations): - def __init__(self, username=None): super().__init__() self.username = username or self.get_user() @@ -22,22 +21,28 @@ def __init__(self, username=None): # Command execution def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): if isinstance(cmd, list): - cmd = ' '.join(cmd) + cmd = " ".join(cmd) log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") # Source global profile file + execute command try: - process = subprocess.run(cmd, shell=True, text=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - timeout=CMD_TIMEOUT_SEC) + process = subprocess.run( + cmd, + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + timeout=CMD_TIMEOUT_SEC, + ) exit_status = process.returncode result = process.stdout error = process.stderr if expect_error: raise Exception(result, error) - if exit_status != 0 or 'error' in error.lower(): - log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") + if exit_status != 0 or "error" in error.lower(): + log.error( + f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" + ) exit(1) if verbose: @@ -55,7 +60,7 @@ def environ(self, var_name): return self.exec_command(cmd).strip() def find_executable(self, executable): - search_paths = self.environ('PATH') + search_paths = self.environ("PATH") if not search_paths: return None @@ -74,12 +79,12 @@ def is_executable(self, file): def add_to_path(self, new_path): pathsep = self.pathsep # Check if the directory is already in PATH - path = self.environ('PATH') + path = self.environ("PATH") if new_path not in path.split(pathsep): if self.remote: self.exec_command(f"export PATH={new_path}{pathsep}{path}") else: - os.environ['PATH'] = f"{new_path}{pathsep}{path}" + os.environ["PATH"] = f"{new_path}{pathsep}{path}" return pathsep def set_env(self, var_name, var_val): @@ -112,10 +117,10 @@ def path_exists(self, path): @property def pathsep(self): os_name = self.get_name() - if os_name == 'posix': - pathsep = ':' - elif os_name == 'nt': - pathsep = ';' + if os_name == "posix": + pathsep = ":" + elif os_name == "nt": + pathsep = ";" else: raise Exception(f"Unsupported operating system: {os_name}") return pathsep @@ -123,6 +128,11 @@ def pathsep(self): def mkdtemp(self, prefix=None): return tempfile.mkdtemp(prefix=prefix) + def mkstemp(self, prefix=None): + fd, filename = tempfile.mkstemp(prefix=prefix) + os.close(fd) # Close the file descriptor immediately after creating the file + return filename + def copytree(self, src, dst): return shutil.copytree(src, dst) @@ -140,11 +150,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal read_and_write: If True, the file will be opened with read and write permissions ('r+' option); if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) """ - mode = 'wb' if binary else 'w' + mode = "wb" if binary else "w" if not truncate: - mode = 'a' + mode + mode = "a" + mode if read_and_write: - mode = 'r+' + mode + mode = "r+" + mode with open(filename, mode) as file: if isinstance(data, list): @@ -161,11 +171,11 @@ def touch(self, filename): This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. """ # cross-python touch(). It is vulnerable to races, but who cares? 
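
Alongside mkdtemp(), this patch adds a LocalOperations.mkstemp() helper; a
minimal sketch, with hypothetical prefixes:

    ops = LocalOperations()
    tmp_file = ops.mkstemp(prefix="tmp_dump_")  # file created, fd closed, path returned
    tmp_dir = ops.mkdtemp(prefix="tmp_node_")   # temporary directory path
    # Note: the remote counterpart shells out to `mkdtemp -d {prefix}`; the
    # standard coreutils spelling is `mktemp -d`, so the remote path likely
    # needs that command instead.
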
- with open(filename, 'a'): + with open(filename, "a"): os.utime(filename, None) def read(self, filename): - with open(filename, 'r') as file: + with open(filename, "r") as file: return file.read() def readlines(self, filename, num_lines=0, encoding=None): @@ -176,14 +186,14 @@ def readlines(self, filename, num_lines=0, encoding=None): assert num_lines >= 0 if num_lines == 0: - with open(filename, 'r', encoding=encoding) as file: + with open(filename, "r", encoding=encoding) as file: return file.readlines() else: bufsize = 8192 buffers = 1 - with open(filename, 'r', encoding=encoding) as file: + with open(filename, "r", encoding=encoding) as file: file.seek(0, os.SEEK_END) end_pos = file.tell() @@ -197,7 +207,9 @@ def readlines(self, filename, num_lines=0, encoding=None): if cur_lines >= num_lines or pos == 0: return lines[-num_lines:] - buffers = int(buffers * max(2, int(num_lines / max(cur_lines, 1)))) # Adjust buffer size + buffers = int( + buffers * max(2, int(num_lines / max(cur_lines, 1))) + ) # Adjust buffer size def isfile(self, remote_file): return os.path.isfile(remote_file) @@ -205,7 +217,7 @@ def isfile(self, remote_file): # Processes control def kill(self, pid, signal): # Kill the process - cmd = f'kill -{signal} {pid}' + cmd = f"kill -{signal} {pid}" return self.exec_command(cmd) def get_pid(self): @@ -213,7 +225,7 @@ def get_pid(self): return os.getpid() # Database control - def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): conn = pglib.connect( host=host, port=port, diff --git a/testgres/op_ops/os_ops.py b/testgres/os_ops/os_ops.py similarity index 90% rename from testgres/op_ops/os_ops.py rename to testgres/os_ops/os_ops.py index 89de2640..1ee1f869 100644 --- a/testgres/op_ops/os_ops.py +++ b/testgres/os_ops/os_ops.py @@ -1,18 +1,15 @@ try: - import psycopg2 as pglib + import psycopg2 as pglib # noqa: F401 except ImportError: try: - import pg8000 as pglib + import pg8000 as pglib # noqa: F401 except ImportError: raise ImportError("You must have psycopg2 or pg8000 modules installed") -from testgres.defaults import default_username - class OsOperations: - def __init__(self, username=None): - self.hostname = 'localhost' + self.hostname = "localhost" self.remote = False self.ssh = None self.username = username @@ -49,7 +46,7 @@ def get_name(self): # Work with dirs def makedirs(self, path, remove_existing=False): raise NotImplementedError() - + def rmdirs(self, path, ignore_errors=True): raise NotImplementedError() @@ -95,5 +92,5 @@ def get_pid(self): raise NotImplementedError() # Database control - def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): raise NotImplementedError() diff --git a/testgres/op_ops/remote_ops.py b/testgres/os_ops/remote_ops.py similarity index 79% rename from testgres/op_ops/remote_ops.py rename to testgres/os_ops/remote_ops.py index d5faab4e..e1460b75 100644 --- a/testgres/op_ops/remote_ops.py +++ b/testgres/os_ops/remote_ops.py @@ -28,7 +28,9 @@ class RemoteOperations(OsOperations): - ssh (paramiko.SSHClient): SSH connection to the remote system. 
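
    Example (illustrative sketch; the host and key path are hypothetical):
        ops = RemoteOperations(host="10.0.0.5", ssh_key="~/.ssh/id_rsa")
        uname = ops.exec_command("uname -a")      # stdout as a string
        if ops.path_exists("/tmp/testgres"):
            entries = ops.listdir("/tmp/testgres")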
""" - def __init__(self, hostname='localhost', host='127.0.0.1', ssh_key=None, username=None): + def __init__( + self, hostname="localhost", host="127.0.0.1", ssh_key=None, username=None + ): super().__init__(username) self.hostname = hostname self.host = host @@ -46,9 +48,9 @@ def ssh_connect(self): if not self.remote: yield None else: - with open(self.ssh_key, 'r') as f: + with open(self.ssh_key, "r") as f: key_data = f.read() - if 'BEGIN OPENSSH PRIVATE KEY' in key_data: + if "BEGIN OPENSSH PRIVATE KEY" in key_data: key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) else: key = paramiko.RSAKey.from_private_key_file(self.ssh_key) @@ -63,9 +65,11 @@ def connect(self): return ssh # Command execution - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding='utf-8'): + def exec_command( + self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding="utf-8" + ): if isinstance(cmd, list): - cmd = ' '.join(cmd) + cmd = " ".join(cmd) log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") # Source global profile file + execute command try: @@ -80,8 +84,10 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, if expect_error: raise Exception(result, error) - if exit_status != 0 or 'error' in error.lower(): - log.error(f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}") + if exit_status != 0 or "error" in error.lower(): + log.error( + f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" + ) exit(1) if verbose: @@ -99,7 +105,7 @@ def environ(self, var_name): return self.exec_command(cmd).strip() def find_executable(self, executable): - search_paths = self.environ('PATH') + search_paths = self.environ("PATH") if not search_paths: return None @@ -113,17 +119,17 @@ def find_executable(self, executable): def is_executable(self, file): # Check if the file is executable - return self.exec_command(f"test -x {file} && echo OK") == 'OK\n' + return self.exec_command(f"test -x {file} && echo OK") == "OK\n" def add_to_path(self, new_path): pathsep = self.pathsep # Check if the directory is already in PATH - path = self.environ('PATH') + path = self.environ("PATH") if new_path not in path.split(pathsep): if self.remote: self.exec_command(f"export PATH={new_path}{pathsep}{path}") else: - os.environ['PATH'] = f"{new_path}{pathsep}{path}" + os.environ["PATH"] = f"{new_path}{pathsep}{path}" return pathsep def set_env(self, var_name, var_val): @@ -132,7 +138,7 @@ def set_env(self, var_name, var_val): # Get environment variables def get_user(self): - return self.exec_command(f"echo $USER") + return self.exec_command("echo $USER") def get_name(self): cmd = 'python3 -c "import os; print(os.name)"' @@ -141,40 +147,45 @@ def get_name(self): # Work with dirs def makedirs(self, path, remove_existing=False): if remove_existing: - cmd = f'rm -rf {path} && mkdir -p {path}' + cmd = f"rm -rf {path} && mkdir -p {path}" else: - cmd = f'mkdir -p {path}' + cmd = f"mkdir -p {path}" return self.exec_command(cmd) def rmdirs(self, path, ignore_errors=True): - cmd = f'rm -rf {path}' + cmd = f"rm -rf {path}" return self.exec_command(cmd) def listdir(self, path): - result = self.exec_command(f'ls {path}') + result = self.exec_command(f"ls {path}") return result.splitlines() def path_exists(self, path): - result = self.exec_command(f'test -e {path}; echo $?') + result = self.exec_command(f"test -e {path}; echo $?") return int(result.strip()) == 0 @property def pathsep(self): os_name = 
self.get_name() - if os_name == 'posix': - pathsep = ':' - elif os_name == 'nt': - pathsep = ';' + if os_name == "posix": + pathsep = ":" + elif os_name == "nt": + pathsep = ";" else: raise Exception(f"Unsupported operating system: {os_name}") return pathsep def mkdtemp(self, prefix=None): - temp_dir = self.exec_command(f'mkdtemp -d {prefix}') + temp_dir = self.exec_command(f"mkdtemp -d {prefix}") return temp_dir.strip() + def mkstemp(self, prefix=None): + cmd = f"mktemp {prefix}XXXXXX" + filename = self.exec_command(cmd).strip() + return filename + def copytree(self, src, dst): - return self.exec_command(f'cp -r {src} {dst}') + return self.exec_command(f"cp -r {src} {dst}") # Work with files def write(self, filename, data, truncate=False, binary=False, read_and_write=False): @@ -190,11 +201,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal read_and_write: If True, the file will be opened with read and write permissions ('r+' option); if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) """ - mode = 'wb' if binary else 'w' + mode = "wb" if binary else "w" if not truncate: - mode = 'a' + mode + mode = "a" + mode if read_and_write: - mode = 'r+' + mode + mode = "r+" + mode with tempfile.NamedTemporaryFile(mode=mode) as tmp_file: if isinstance(data, list): @@ -216,38 +227,38 @@ def touch(self, filename): This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. """ - self.exec_command(f'touch {filename}') + self.exec_command(f"touch {filename}") - def read(self, filename, encoding='utf-8'): - cmd = f'cat {filename}' + def read(self, filename, encoding="utf-8"): + cmd = f"cat {filename}" return self.exec_command(cmd, encoding=encoding) def readlines(self, filename, num_lines=0, encoding=None): - encoding = encoding or 'utf-8' + encoding = encoding or "utf-8" if num_lines > 0: - cmd = f'tail -n {num_lines} {filename}' + cmd = f"tail -n {num_lines} {filename}" lines = self.exec_command(cmd, encoding) else: lines = self.read(filename, encoding=encoding).splitlines() return lines def isfile(self, remote_file): - stdout = self.exec_command(f'test -f {remote_file}; echo $?') + stdout = self.exec_command(f"test -f {remote_file}; echo $?") result = int(stdout.strip()) return result == 0 # Processes control def kill(self, pid, signal): # Kill the process - cmd = f'kill -{signal} {pid}' + cmd = f"kill -{signal} {pid}" return self.exec_command(cmd) def get_pid(self): # Get current process id - return self.exec_command(f"echo $$") + return self.exec_command("echo $$") # Database control - def db_connect(self, dbname, user, password=None, host='localhost', port=5432): + def db_connect(self, dbname, user, password=None, host="localhost", port=5432): local_port = self.ssh.forward_remote_port(host, port) conn = pglib.connect( host=host, diff --git a/testgres/utils.py b/testgres/utils.py index 73ca6f1a..72fd1b9d 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -12,6 +12,9 @@ from contextlib import contextmanager from packaging.version import Version + +from .os_ops.remote_ops import RemoteOperations + try: from shutil import which as find_executable except ImportError: @@ -19,8 +22,8 @@ from six import iteritems from fabric import Connection -from .op_ops.local_ops import LocalOperations -from .op_ops.os_ops import OsOperations +from .os_ops.local_ops import LocalOperations +from .os_ops.os_ops import OsOperations from .config import testgres_config from .exceptions import 
ExecUtilException @@ -59,6 +62,7 @@ def execute_utility(args, logfile=None, os_ops: OsOperations = LocalOperations() Execute utility (pg_ctl, pg_dump etc). Args: + os_ops: LocalOperations for local node or RemoteOperations for node that connected by ssh. args: utility + arguments (list). logfile: path to file to store stdout and stderr. @@ -66,21 +70,20 @@ def execute_utility(args, logfile=None, os_ops: OsOperations = LocalOperations() stdout of executed utility. """ - if os_ops.hostname != 'localhost': + if isinstance(os_ops, RemoteOperations): conn = Connection( os_ops.hostname, connect_kwargs={ "key_filename": f"{os_ops.ssh_key}", }, ) - # TODO skip remote ssh run if we are on the localhost. # result = conn.run('hostname', hide=True) - # add logger + # add logger cmd = ' '.join(args) - result = conn.run(cmd, hide=True) - + result = conn.run(cmd, hide=True) + return result # run utility @@ -173,8 +176,9 @@ def get_bin_path(filename): def get_pg_config(pg_config_path=None): """ Return output of pg_config (provided that it is installed). - NOTE: this fuction caches the result by default (see GlobalConfig). + NOTE: this function caches the result by default (see GlobalConfig). """ + def cache_pg_config_data(cmd): # execute pg_config and get the output out = subprocess.check_output([cmd]).decode('utf-8') From b04804102457adf6ef29646e9dd86ceaccb24127 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 12 Jun 2023 23:08:10 +0200 Subject: [PATCH 09/23] PBCKP-152 fix failed tests --- testgres/__init__.py | 5 ++ testgres/cache.py | 7 +- testgres/config.py | 2 +- testgres/connection.py | 24 ++---- testgres/defaults.py | 2 +- testgres/node.py | 21 +++-- testgres/operations/__init__.py | 0 testgres/{os_ops => operations}/local_ops.py | 45 ++++++----- testgres/{os_ops => operations}/os_ops.py | 0 testgres/{os_ops => operations}/remote_ops.py | 55 +++++++++---- testgres/utils.py | 8 +- tests/test_remote.py | 81 +++++++++++++++++++ tests/test_simple.py | 6 +- 13 files changed, 181 insertions(+), 75 deletions(-) create mode 100644 testgres/operations/__init__.py rename testgres/{os_ops => operations}/local_ops.py (82%) rename testgres/{os_ops => operations}/os_ops.py (100%) rename testgres/{os_ops => operations}/remote_ops.py (84%) create mode 100755 tests/test_remote.py diff --git a/testgres/__init__.py b/testgres/__init__.py index 1b33ba3b..405262dd 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -46,6 +46,10 @@ First, \ Any +from .operations.os_ops import OsOperations +from .operations.local_ops import LocalOperations +from .operations.remote_ops import RemoteOperations + __all__ = [ "get_new_node", "NodeBackup", @@ -56,4 +60,5 @@ "PostgresNode", "NodeApp", "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version", "First", "Any", + "OsOperations", "LocalOperations", "RemoteOperations" ] diff --git a/testgres/cache.py b/testgres/cache.py index 1df5a8ea..ef07e976 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -4,8 +4,6 @@ from six import raise_from -from .os_ops.local_ops import LocalOperations -from .os_ops.os_ops import OsOperations from .config import testgres_config from .consts import XLOG_CONTROL_FILE @@ -20,6 +18,9 @@ get_bin_path, \ execute_utility +from .operations.local_ops import LocalOperations +from .operations.os_ops import OsOperations + def cached_initdb(data_dir, logfile=None, params=None, os_ops: OsOperations = LocalOperations()): """ @@ -38,7 +39,7 @@ def call_initdb(initdb_dir, log=None): call_initdb(data_dir, 
logfile) else: # Fetch cached initdb dir - cached_data_dir = testgres_config.cached_initdb_dir() + cached_data_dir = testgres_config.cached_initdb_dir # Initialize cached initdb diff --git a/testgres/config.py b/testgres/config.py index fd942664..b21d8356 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -6,8 +6,8 @@ from contextlib import contextmanager -from .os_ops.local_ops import LocalOperations from .consts import TMP_CACHE +from .operations.local_ops import LocalOperations class GlobalConfig(object): diff --git a/testgres/connection.py b/testgres/connection.py index f3d01ebe..6725b14f 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -102,23 +102,15 @@ def rollback(self): return self def execute(self, query, *args): + self.cursor.execute(query, args) try: - with self.connection.cursor() as cursor: - cursor.execute(query, args) - try: - res = cursor.fetchall() - - # pg8000 might return tuples - if isinstance(res, tuple): - res = [tuple(t) for t in res] - - return res - except (pglib.ProgrammingError, pglib.InternalError) as e: - # An error occurred while trying to fetch results (e.g., no results to fetch) - print(f"Error fetching results: {e}") - return None - except (pglib.Error, Exception) as e: - # Handle other database errors + res = self.cursor.fetchall() + # pg8000 might return tuples + if isinstance(res, tuple): + res = [tuple(t) for t in res] + + return res + except Exception as e: print(f"Error executing query: {e}") return None diff --git a/testgres/defaults.py b/testgres/defaults.py index 5ffc08de..34bcc08b 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -2,7 +2,7 @@ import struct import uuid -from .os_ops.local_ops import LocalOperations +from .operations.local_ops import LocalOperations def default_dbname(): diff --git a/testgres/node.py b/testgres/node.py index 6456e5a9..9aa47d84 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -9,9 +9,6 @@ import psutil import time -from .os_ops.local_ops import LocalOperations -from .os_ops.remote_ops import RemoteOperations - try: from collections.abc import Iterable except ImportError: @@ -99,6 +96,9 @@ from .backup import NodeBackup +from .operations.local_ops import LocalOperations +from .operations.remote_ops import RemoteOperations + InternalError = pglib.InternalError ProgrammingError = pglib.ProgrammingError OperationalError = pglib.OperationalError @@ -201,7 +201,7 @@ def pid(self): if self.status(): pid_file = os.path.join(self.data_dir, PG_PID_FILE) - lines = self.os_ops.readlines(pid_file, num_lines=1) + lines = self.os_ops.readlines(pid_file) pid = int(lines[0]) if lines else None return pid @@ -433,7 +433,8 @@ def _collect_special_files(self): if not self.os_ops.path_exists(f): continue - lines = b''.join(self.os_ops.readlines(f, num_lines, encoding='utf-8')) + file_lines = self.os_ops.readlines(f, num_lines, binary=True, encoding=None) + lines = b''.join(file_lines) # fill list result.append((f, lines)) @@ -498,7 +499,7 @@ def default_conf(self, ] # write filtered lines - self.os_ops.write(hba_conf_file, lines, truncate=True) + self.os_ops.write(hba_conf, lines, truncate=True) # replication-related settings if allow_streaming: @@ -960,11 +961,9 @@ def psql(self, psql_params.append(dbname) # start psql process - process = self.os_ops.exec_command(psql_params) + status_code, out, err = self.os_ops.exec_command(psql_params, shell=False, verbose=True, input=input) - # wait until it finishes and get stdout and stderr - out, err = process.communicate(input=input) - return 
process.returncode, out, err + return status_code, out, err @method_decorator(positional_args_hack(['dbname', 'query'])) def safe_psql(self, query=None, expect_error=False, **kwargs): @@ -1348,7 +1347,7 @@ def pgbench(self, # should be the last one _params.append(dbname) - proc = self.os_ops.exec_command(_params, wait_exit=True) + proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, shell=False, proc=True) return proc diff --git a/testgres/operations/__init__.py b/testgres/operations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/testgres/os_ops/local_ops.py b/testgres/operations/local_ops.py similarity index 82% rename from testgres/os_ops/local_ops.py rename to testgres/operations/local_ops.py index d6977c9f..acb10df8 100644 --- a/testgres/os_ops/local_ops.py +++ b/testgres/operations/local_ops.py @@ -19,18 +19,25 @@ def __init__(self, username=None): self.username = username or self.get_user() # Command execution - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): - if isinstance(cmd, list): - cmd = " ".join(cmd) + def exec_command(self, cmd, wait_exit=False, verbose=False, + expect_error=False, encoding=None, shell=True, text=False, + input=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") # Source global profile file + execute command try: + if proc: + return subprocess.Popen(cmd, + shell=shell, + stdin=input or subprocess.PIPE, + stdout=stdout, + stderr=stderr) process = subprocess.run( cmd, - shell=True, - text=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + input=input, + shell=shell, + text=text, + stdout=stdout, + stderr=stderr, timeout=CMD_TIMEOUT_SEC, ) exit_status = process.returncode @@ -39,11 +46,11 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): if expect_error: raise Exception(result, error) - if exit_status != 0 or "error" in error.lower(): + if exit_status != 0 or "error" in error.lower().decode(encoding or 'utf-8'): # Decode error for comparison log.error( - f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" + f"Problem in executing command: `{cmd}`\nerror: {error.decode(encoding or 'utf-8')}\nexit_code: {exit_status}" + # Decode for logging ) - exit(1) if verbose: return exit_status, result, error @@ -152,9 +159,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal """ mode = "wb" if binary else "w" if not truncate: - mode = "a" + mode + mode = "ab" if binary else "a" if read_and_write: - mode = "r+" + mode + mode = "r+b" if binary else "r+" with open(filename, mode) as file: if isinstance(data, list): @@ -174,26 +181,26 @@ def touch(self, filename): with open(filename, "a"): os.utime(filename, None) - def read(self, filename): - with open(filename, "r") as file: + def read(self, filename, encoding=None): + with open(filename, "r", encoding=encoding) as file: return file.read() - def readlines(self, filename, num_lines=0, encoding=None): + def readlines(self, filename, num_lines=0, binary=False, encoding=None): """ Read lines from a local file. If num_lines is greater than 0, only the last num_lines lines will be read. 
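        A minimal usage sketch (an editorial illustration, not part of this
        patch; the log path below is an assumed example):

            ops = LocalOperations()
            # last five lines of a server log, decoded as text
            tail = ops.readlines('/tmp/node1/postgresql.log', num_lines=5)
            # the same tail as raw bytes, e.g. for logs with a broken encoding
            raw = ops.readlines('/tmp/node1/postgresql.log', num_lines=5, binary=True)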
""" assert num_lines >= 0 - + mode = 'rb' if binary else 'r' if num_lines == 0: - with open(filename, "r", encoding=encoding) as file: + with open(filename, mode, encoding=encoding) as file: # open in binary mode return file.readlines() else: bufsize = 8192 buffers = 1 - with open(filename, "r", encoding=encoding) as file: + with open(filename, mode, encoding=encoding) as file: # open in binary mode file.seek(0, os.SEEK_END) end_pos = file.tell() @@ -205,7 +212,7 @@ def readlines(self, filename, num_lines=0, encoding=None): cur_lines = len(lines) if cur_lines >= num_lines or pos == 0: - return lines[-num_lines:] + return lines[-num_lines:] # get last num_lines from lines buffers = int( buffers * max(2, int(num_lines / max(cur_lines, 1))) diff --git a/testgres/os_ops/os_ops.py b/testgres/operations/os_ops.py similarity index 100% rename from testgres/os_ops/os_ops.py rename to testgres/operations/os_ops.py diff --git a/testgres/os_ops/remote_ops.py b/testgres/operations/remote_ops.py similarity index 84% rename from testgres/os_ops/remote_ops.py rename to testgres/operations/remote_ops.py index e1460b75..dbe88dbe 100644 --- a/testgres/os_ops/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,3 +1,4 @@ +import io import os import tempfile from contextlib import contextmanager @@ -65,9 +66,9 @@ def connect(self): return ssh # Command execution - def exec_command( - self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding="utf-8" - ): + def exec_command(self, cmd, wait_exit=False, verbose=False, + expect_error=False, encoding=None, shell=True, text=False, + input=None, stdout=None, stderr=None, proc=None): if isinstance(cmd, list): cmd = " ".join(cmd) log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") @@ -75,20 +76,31 @@ def exec_command( try: cmd = f"source /etc/profile.d/custom.sh; {cmd}" with self.ssh_connect() as ssh: - stdin, stdout, stderr = ssh.exec_command(cmd) + if input: + # encode input and feed it to stdin + stdin, stdout, stderr = ssh.exec_command(cmd) + stdin.write(input) + stdin.flush() + else: + stdin, stdout, stderr = ssh.exec_command(cmd) exit_status = 0 if wait_exit: exit_status = stdout.channel.recv_exit_status() - result = stdout.read().decode(encoding) - error = stderr.read().decode(encoding) + if encoding: + result = stdout.read().decode(encoding) + error = stderr.read().decode(encoding) + else: + # Save as binary string + result = io.BytesIO(stdout.read()).getvalue() + error = io.BytesIO(stderr.read()).getvalue() + error_str = stderr.read() if expect_error: raise Exception(result, error) - if exit_status != 0 or "error" in error.lower(): + if exit_status != 0 or 'error' in error_str: log.error( f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" ) - exit(1) if verbose: return exit_status, result, error @@ -203,9 +215,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal """ mode = "wb" if binary else "w" if not truncate: - mode = "a" + mode + mode = "ab" if binary else "a" if read_and_write: - mode = "r+" + mode + mode = "r+b" if binary else "r+" with tempfile.NamedTemporaryFile(mode=mode) as tmp_file: if isinstance(data, list): @@ -229,17 +241,28 @@ def touch(self, filename): """ self.exec_command(f"touch {filename}") - def read(self, filename, encoding="utf-8"): + def read(self, filename, binary=False, encoding=None): cmd = f"cat {filename}" - return self.exec_command(cmd, encoding=encoding) + result = self.exec_command(cmd, encoding=encoding) + + if not binary and 
result: + result = result.decode(encoding or 'utf-8') + + return result - def readlines(self, filename, num_lines=0, encoding=None): - encoding = encoding or "utf-8" + def readlines(self, filename, num_lines=0, binary=False, encoding=None): if num_lines > 0: cmd = f"tail -n {num_lines} {filename}" - lines = self.exec_command(cmd, encoding) else: - lines = self.read(filename, encoding=encoding).splitlines() + cmd = f"cat {filename}" + + result = self.exec_command(cmd, encoding=encoding) + + if not binary and result: + lines = result.decode(encoding or 'utf-8').splitlines() + else: + lines = result.splitlines() + return lines def isfile(self, remote_file): diff --git a/testgres/utils.py b/testgres/utils.py index 72fd1b9d..b72c7da0 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -13,8 +13,6 @@ from contextlib import contextmanager from packaging.version import Version -from .os_ops.remote_ops import RemoteOperations - try: from shutil import which as find_executable except ImportError: @@ -22,8 +20,10 @@ from six import iteritems from fabric import Connection -from .os_ops.local_ops import LocalOperations -from .os_ops.os_ops import OsOperations + +from .operations.remote_ops import RemoteOperations +from .operations.local_ops import LocalOperations +from .operations.os_ops import OsOperations from .config import testgres_config from .exceptions import ExecUtilException diff --git a/tests/test_remote.py b/tests/test_remote.py new file mode 100755 index 00000000..47804dfb --- /dev/null +++ b/tests/test_remote.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os +import time + +import pytest +from docker import DockerClient +from paramiko import RSAKey + +from testgres import RemoteOperations + + +class TestRemoteOperations: + @pytest.fixture(scope="class", autouse=True) + def setup_class(self): + # Create shared volume + self.volume_path = os.path.abspath("./tmp/ssh_key") + os.makedirs(self.volume_path, exist_ok=True) + + # Generate SSH keys + private_key_path = os.path.join(self.volume_path, "id_rsa") + public_key_path = os.path.join(self.volume_path, "id_rsa.pub") + + private_key = RSAKey.generate(4096) + private_key.write_private_key_file(private_key_path) + + with open(public_key_path, "w") as f: + f.write(f"{private_key.get_name()} {private_key.get_base64()}") + + self.docker = DockerClient.from_env() + self.container = self.docker.containers.run( + "rastasheep/ubuntu-sshd:18.04", + detach=True, + tty=True, + ports={22: 8022}, + ) + + # Wait for the container to start sshd + time.sleep(10) + + yield + + # Stop and remove the container after tests + self.container.stop() + self.container.remove() + + @pytest.fixture(scope="function", autouse=True) + def setup(self): + self.operations = RemoteOperations( + host="localhost", + username="root", + ssh_key=os.path.join(self.volume_path, "id_rsa") + ) + + yield + + self.operations.__del__() + + def test_exec_command(self): + cmd = "python3 --version" + response = self.operations.exec_command(cmd) + + assert "Python 3.9" in response + + def test_is_executable(self): + cmd = "python3" + response = self.operations.is_executable(cmd) + + assert response is True + + def test_makedirs_and_rmdirs(self): + path = "/test_dir" + + # Test makedirs + self.operations.makedirs(path) + assert self.operations.path_exists(path) + + # Test rmdirs + self.operations.rmdirs(path) + assert not self.operations.path_exists(path) diff --git a/tests/test_simple.py b/tests/test_simple.py index 94420b04..e8b8abee 100755 --- 
a/tests/test_simple.py +++ b/tests/test_simple.py @@ -151,8 +151,6 @@ def test_init_unique_system_id(self): self.assertGreater(id2, id1) def test_node_exit(self): - base_dir = None - with self.assertRaises(QueryException): with get_new_node().init() as node: base_dir = node.base_dir @@ -281,7 +279,7 @@ def test_psql(self): node.safe_psql('copy horns from stdin (format csv)', input=b"1\n2\n3\n\\.\n") _sum = node.safe_psql('select sum(w) from horns') - self.assertEqual(_sum, b'6\n') + self.assertEqual(b'6\n', _sum) # check psql's default args, fails with self.assertRaises(QueryException): @@ -614,7 +612,7 @@ def test_users(self): with get_new_node().init().start() as node: node.psql('create role test_user login') value = node.safe_psql('select 1', username='test_user') - self.assertEqual(value, b'1\n') + self.assertEqual(b'1\n', value) def test_poll_query_until(self): with get_new_node() as node: From e098b9796de56f62a5347cdd0fa6576a91ac7b40 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Tue, 13 Jun 2023 22:04:51 +0200 Subject: [PATCH 10/23] PBCKP-152 fix failed tests --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2b188565..b5162dce 100755 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ setup( version='1.9.0', name='testgres', - packages=['testgres'], + packages=['testgres', 'testgres.operations'], description='Testing utility for PostgreSQL and its extensions', url='https://github.com/postgrespro/testgres', long_description=readme, From 1c405ef7dde782d6dd6507708ceb5c35a2e33440 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Wed, 14 Jun 2023 11:15:33 +0200 Subject: [PATCH 11/23] PBCKP-152 add tests for remote_ops.py --- testgres/node.py | 6 +- testgres/operations/remote_ops.py | 243 +++++++++++++++++++----------- tests/test_remote.py | 210 +++++++++++++++++++------- 3 files changed, 313 insertions(+), 146 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 9aa47d84..9d183f96 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -152,10 +152,10 @@ def __init__(self, name=None, port=None, base_dir=None, self.host = host self.hostname = hostname self.ssh_key = ssh_key - if hostname == 'localhost' or host == '127.0.0.1': - self.os_ops = LocalOperations(username=username) - else: + if hostname != 'localhost' or host != '127.0.0.1': self.os_ops = RemoteOperations(host, hostname, ssh_key) + else: + self.os_ops = LocalOperations(username=username) testgres_config.os_ops = self.os_ops # defaults for __exit__() diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index dbe88dbe..e2248015 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,106 +1,98 @@ -import io import os import tempfile -from contextlib import contextmanager +from typing import Optional -from testgres.logger import log +import paramiko +from paramiko import SSHClient +from logger import log from .os_ops import OsOperations from .os_ops import pglib -import paramiko +error_markers = [b'error', b'Permission denied'] class RemoteOperations(OsOperations): - """ - This class specifically supports work with Linux systems. It utilizes the SSH - for making connections and performing various file and directory operations, command executions, - environment setup and management, process control, and database connections. - It uses the Paramiko library for SSH connections and operations. 
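(Editorial note: a short sketch of how the reworked class below is meant to be
constructed and used; the host, user, and key path are assumed values, not
part of this patch.)

    ops = RemoteOperations(
        hostname='pg-test',                     # informational host name
        host='10.0.0.5',                        # IP used for the SSH session
        ssh_key='/home/user/.ssh/id_ed25519',
        username='postgres',
    )
    print(ops.exec_command('whoami', encoding='utf-8'))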
- - Some methods are designed to work with specific Linux shell commands, and thus may not work as expected - on other non-Linux systems. - - Attributes: - - hostname (str): The remote system's hostname. Default 'localhost'. - - host (str): The remote system's IP address. Default '127.0.0.1'. - - ssh_key (str): Path to the SSH private key for authentication. - - username (str): Username for the remote system. - - ssh (paramiko.SSHClient): SSH connection to the remote system. - """ - - def __init__( - self, hostname="localhost", host="127.0.0.1", ssh_key=None, username=None - ): + def __init__(self, hostname="localhost", host="127.0.0.1", ssh_key=None, username=None): super().__init__(username) - self.hostname = hostname self.host = host self.ssh_key = ssh_key self.remote = True - self.ssh = self.connect() + self.ssh = self.ssh_connect() self.username = username or self.get_user() def __del__(self): if self.ssh: self.ssh.close() - @contextmanager - def ssh_connect(self): + def ssh_connect(self) -> Optional[SSHClient]: if not self.remote: - yield None + return None else: + key = self._read_ssh_key() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.host, username=self.username, pkey=key) + return ssh + + def _read_ssh_key(self): + try: with open(self.ssh_key, "r") as f: key_data = f.read() if "BEGIN OPENSSH PRIVATE KEY" in key_data: key = paramiko.Ed25519Key.from_private_key_file(self.ssh_key) else: key = paramiko.RSAKey.from_private_key_file(self.ssh_key) + return key + except FileNotFoundError: + log.error(f"No such file or directory: '{self.ssh_key}'") + except Exception as e: + log.error(f"An error occurred while reading the ssh key: {e}") - with paramiko.SSHClient() as ssh: - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.host, username=self.username, pkey=key) - yield ssh - - def connect(self): - with self.ssh_connect() as ssh: - return ssh - - # Command execution - def exec_command(self, cmd, wait_exit=False, verbose=False, - expect_error=False, encoding=None, shell=True, text=False, - input=None, stdout=None, stderr=None, proc=None): + def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, + encoding=None, shell=True, text=False, input=None, stdout=None, + stderr=None, proc=None): + """ + Execute a command in the SSH session. + Args: + - cmd (str): The command to be executed. 
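+        A hedged example of the verbose mode, given a RemoteOperations
+        instance `ops` (the command and its output are assumed):
+
+            status, out, err = ops.exec_command('ls -la /tmp',
+                                                wait_exit=True,
+                                                verbose=True,
+                                                encoding='utf-8')
+            # status is the remote exit code; out and err are decoded strings
+            # here because an encoding was given, otherwise they stay bytes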
+ """ if isinstance(cmd, list): cmd = " ".join(cmd) - log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") - # Source global profile file + execute command try: - cmd = f"source /etc/profile.d/custom.sh; {cmd}" - with self.ssh_connect() as ssh: - if input: - # encode input and feed it to stdin - stdin, stdout, stderr = ssh.exec_command(cmd) - stdin.write(input) - stdin.flush() - else: - stdin, stdout, stderr = ssh.exec_command(cmd) - exit_status = 0 - if wait_exit: - exit_status = stdout.channel.recv_exit_status() - if encoding: - result = stdout.read().decode(encoding) - error = stderr.read().decode(encoding) - else: - # Save as binary string - result = io.BytesIO(stdout.read()).getvalue() - error = io.BytesIO(stderr.read()).getvalue() - error_str = stderr.read() + if input: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + stdin.write(input.encode("utf-8")) + stdin.flush() + else: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + exit_status = 0 + if wait_exit: + exit_status = stdout.channel.recv_exit_status() + + if encoding: + result = stdout.read().decode(encoding) + error = stderr.read().decode(encoding) + else: + result = stdout.read() + error = stderr.read() if expect_error: raise Exception(result, error) - if exit_status != 0 or 'error' in error_str: + + if encoding: + error_found = exit_status != 0 or any( + marker.decode(encoding) in error for marker in error_markers) + else: + error_found = exit_status != 0 or any( + marker in error for marker in error_markers) + + if error_found: log.error( f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" ) + if exit_status == 0: + exit_status = 1 if verbose: return exit_status, result, error @@ -112,7 +104,12 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, return None # Environment setup - def environ(self, var_name): + def environ(self, var_name: str) -> str: + """ + Get the value of an environment variable. + Args: + - var_name (str): The name of the environment variable. + """ cmd = f"echo ${var_name}" return self.exec_command(cmd).strip() @@ -131,7 +128,8 @@ def find_executable(self, executable): def is_executable(self, file): # Check if the file is executable - return self.exec_command(f"test -x {file} && echo OK") == "OK\n" + is_exec = self.exec_command(f"test -x {file} && echo OK") + return is_exec == b"OK\n" def add_to_path(self, new_path): pathsep = self.pathsep @@ -144,8 +142,13 @@ def add_to_path(self, new_path): os.environ["PATH"] = f"{new_path}{pathsep}{path}" return pathsep - def set_env(self, var_name, var_val): - # Check if the directory is already in PATH + def set_env(self, var_name: str, var_val: str) -> None: + """ + Set the value of an environment variable. + Args: + - var_name (str): The name of the environment variable. + - var_val (str): The value to be set for the environment variable. + """ return self.exec_command(f"export {var_name}={var_val}") # Get environment variables @@ -158,22 +161,47 @@ def get_name(self): # Work with dirs def makedirs(self, path, remove_existing=False): + """ + Create a directory in the remote server. + Args: + - path (str): The path to the directory to be created. + - remove_existing (bool): If True, the existing directory at the path will be removed. 
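+        For instance (a sketch; the path is an assumed example):
+
+            ops.makedirs('/tmp/testgres_data', remove_existing=True)
+            # raises an Exception if the remote mkdir reports a failure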
+ """ if remove_existing: cmd = f"rm -rf {path} && mkdir -p {path}" else: cmd = f"mkdir -p {path}" - return self.exec_command(cmd) + exit_status, result, error = self.exec_command(cmd, verbose=True) + if exit_status != 0: + raise Exception(f"Couldn't create dir {path} because of error {error}") + return result - def rmdirs(self, path, ignore_errors=True): + def rmdirs(self, path, verbose=False, ignore_errors=True): + """ + Remove a directory in the remote server. + Args: + - path (str): The path to the directory to be removed. + - verbose (bool): If True, return exit status, result, and error. + - ignore_errors (bool): If True, do not raise error if directory does not exist. + """ cmd = f"rm -rf {path}" - return self.exec_command(cmd) + exit_status, result, error = self.exec_command(cmd, verbose=True) + if verbose: + return exit_status, result, error + else: + return result def listdir(self, path): + """ + List all files and directories in a directory. + Args: + path (str): The path to the directory. + """ result = self.exec_command(f"ls {path}") return result.splitlines() def path_exists(self, path): - result = self.exec_command(f"test -e {path}; echo $?") + result = self.exec_command(f"test -e {path}; echo $?", encoding='utf-8') return int(result.strip()) == 0 @property @@ -188,7 +216,12 @@ def pathsep(self): return pathsep def mkdtemp(self, prefix=None): - temp_dir = self.exec_command(f"mkdtemp -d {prefix}") + """ + Creates a temporary directory in the remote server. + Args: + prefix (str): The prefix of the temporary directory name. + """ + temp_dir = self.exec_command(f"mkdtemp -d {prefix}", encoding='utf-8') return temp_dir.strip() def mkstemp(self, prefix=None): @@ -200,18 +233,19 @@ def copytree(self, src, dst): return self.exec_command(f"cp -r {src} {dst}") # Work with files - def write(self, filename, data, truncate=False, binary=False, read_and_write=False): + def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding='utf-8'): """ Write data to a file on a remote host + Args: - filename: The file path where the data will be written. - data: The data to be written to the file. - truncate: If True, the file will be truncated before writing ('w' or 'wb' option); - if False (default), data will be appended ('a' or 'ab' option). - binary: If True, the data will be written in binary mode ('wb' or 'ab' option); - if False (default), the data will be written in text mode ('w' or 'a' option). - read_and_write: If True, the file will be opened with read and write permissions ('r+' option); - if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) + - filename (str): The file path where the data will be written. + - data (bytes or str): The data to be written to the file. + - truncate (bool): If True, the file will be truncated before writing ('w' or 'wb' option); + if False (default), data will be appended ('a' or 'ab' option). + - binary (bool): If True, the data will be written in binary mode ('wb' or 'ab' option); + if False (default), the data will be written in text mode ('w' or 'a' option). + - read_and_write (bool): If True, the file will be opened with read and write permissions ('r+' option); + if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option). 
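+        Two illustrative calls (paths and contents are assumed examples):
+
+            ops.write('/tmp/extra.conf', 'fsync = off\n', truncate=True)
+            ops.write('/tmp/blob.bin', b'\x00\x01\x02', binary=True, truncate=True)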
""" mode = "wb" if binary else "w" if not truncate: @@ -220,15 +254,18 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal mode = "r+b" if binary else "r+" with tempfile.NamedTemporaryFile(mode=mode) as tmp_file: - if isinstance(data, list): - tmp_file.writelines(data) - else: - tmp_file.write(data) + if isinstance(data, bytes) and not binary: + data = data.decode(encoding) + elif isinstance(data, str) and binary: + data = data.encode(encoding) + + tmp_file.write(data) tmp_file.flush() - sftp = self.ssh.open_sftp() - sftp.put(tmp_file.name, filename) - sftp.close() + with self.ssh_connect() as ssh: + sftp = ssh.open_sftp() + sftp.put(tmp_file.name, filename) + sftp.close() def touch(self, filename): """ @@ -281,8 +318,29 @@ def get_pid(self): return self.exec_command("echo $$") # Database control - def db_connect(self, dbname, user, password=None, host="localhost", port=5432): - local_port = self.ssh.forward_remote_port(host, port) + def db_connect(self, dbname, user, password=None, host="127.0.0.1", hostname="localhost", port=5432): + """ + Connects to a PostgreSQL database on the remote system. + Args: + - dbname (str): The name of the database to connect to. + - user (str): The username for the database connection. + - password (str, optional): The password for the database connection. Defaults to None. + - host (str, optional): The IP address of the remote system. Defaults to "127.0.0.1". + - hostname (str, optional): The hostname of the remote system. Defaults to "localhost". + - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432. + + This function establishes a connection to a PostgreSQL database on the remote system using the specified + parameters. It returns a connection object that can be used to interact with the database. 
+ """ + transport = self.ssh.get_transport() + local_port = 9090 # or any other available port + + transport.open_channel( + 'direct-tcpip', + (hostname, port), + (host, local_port) + ) + conn = pglib.connect( host=host, port=local_port, @@ -291,3 +349,4 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432): password=password, ) return conn + diff --git a/tests/test_remote.py b/tests/test_remote.py index 47804dfb..0155956c 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -1,76 +1,66 @@ -#!/usr/bin/env python -# coding: utf-8 - -import os -import time - import pytest -from docker import DockerClient -from paramiko import RSAKey -from testgres import RemoteOperations +from testgres.operations.remote_ops import RemoteOperations class TestRemoteOperations: - @pytest.fixture(scope="class", autouse=True) - def setup_class(self): - # Create shared volume - self.volume_path = os.path.abspath("./tmp/ssh_key") - os.makedirs(self.volume_path, exist_ok=True) - - # Generate SSH keys - private_key_path = os.path.join(self.volume_path, "id_rsa") - public_key_path = os.path.join(self.volume_path, "id_rsa.pub") - - private_key = RSAKey.generate(4096) - private_key.write_private_key_file(private_key_path) - - with open(public_key_path, "w") as f: - f.write(f"{private_key.get_name()} {private_key.get_base64()}") - - self.docker = DockerClient.from_env() - self.container = self.docker.containers.run( - "rastasheep/ubuntu-sshd:18.04", - detach=True, - tty=True, - ports={22: 8022}, - ) - - # Wait for the container to start sshd - time.sleep(10) - - yield - - # Stop and remove the container after tests - self.container.stop() - self.container.remove() @pytest.fixture(scope="function", autouse=True) def setup(self): self.operations = RemoteOperations( - host="localhost", - username="root", - ssh_key=os.path.join(self.volume_path, "id_rsa") + host="172.18.0.3", + username="dev", + ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519' ) yield self.operations.__del__() - def test_exec_command(self): + def test_exec_command_success(self): + """ + Test exec_command for successful command execution. + """ cmd = "python3 --version" - response = self.operations.exec_command(cmd) + response = self.operations.exec_command(cmd, wait_exit=True) - assert "Python 3.9" in response + assert b'Python 3.' in response - def test_is_executable(self): - cmd = "python3" + def test_exec_command_failure(self): + """ + Test exec_command for command execution failure. + """ + cmd = "nonexistent_command" + exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) + + assert error == b'bash: line 1: nonexistent_command: command not found\n' + + def test_is_executable_true(self): + """ + Test is_executable for an existing executable. + """ + cmd = "postgres" response = self.operations.is_executable(cmd) assert response is True - def test_makedirs_and_rmdirs(self): - path = "/test_dir" + def test_is_executable_false(self): + """ + Test is_executable for a non-executable. + """ + cmd = "python" + response = self.operations.is_executable(cmd) + + assert response is False + + def test_makedirs_and_rmdirs_success(self): + """ + Test makedirs and rmdirs for successful directory creation and removal. 
+ """ + cmd = "pwd" + pwd = self.operations.exec_command(cmd, wait_exit=True, encoding='utf-8').strip() + + path = f"{pwd}/test_dir" # Test makedirs self.operations.makedirs(path) @@ -79,3 +69,121 @@ def test_makedirs_and_rmdirs(self): # Test rmdirs self.operations.rmdirs(path) assert not self.operations.path_exists(path) + + def test_makedirs_and_rmdirs_failure(self): + """ + Test makedirs and rmdirs for directory creation and removal failure. + """ + # Try to create a directory in a read-only location + path = "/root/test_dir" + + # Test makedirs + with pytest.raises(Exception): + self.operations.makedirs(path) + + # Test rmdirs + exit_status, result, error = self.operations.rmdirs(path, verbose=True) + assert error == b"rm: cannot remove '/root/test_dir': Permission denied\n" + + def test_listdir(self): + """ + Test listdir for listing directory contents. + """ + path = "/etc" + files = self.operations.listdir(path) + + assert isinstance(files, list) + + def test_path_exists_true(self): + """ + Test path_exists for an existing path. + """ + path = "/etc" + response = self.operations.path_exists(path) + + assert response is True + + def test_path_exists_false(self): + """ + Test path_exists for a non-existing path. + """ + path = "/nonexistent_path" + response = self.operations.path_exists(path) + + assert response is False + + def test_write_text_file(self): + """ + Test write for writing data to a text file. + """ + filename = "/tmp/test_file.txt" + data = "Hello, world!" + + self.operations.write(filename, data) + + response = self.operations.read(filename) + + assert response == data + + def test_write_binary_file(self): + """ + Test write for writing data to a binary file. + """ + filename = "/tmp/test_file.bin" + data = b"\x00\x01\x02\x03" + + self.operations.write(filename, data, binary=True) + + response = self.operations.read(filename, binary=True) + + assert response == data + + def test_read_text_file(self): + """ + Test read for reading data from a text file. + """ + filename = "/etc/hosts" + + response = self.operations.read(filename) + + assert isinstance(response, str) + + def test_read_binary_file(self): + """ + Test read for reading data from a binary file. + """ + filename = "/usr/bin/python3" + + response = self.operations.read(filename, binary=True) + + assert isinstance(response, bytes) + + def test_touch(self): + """ + Test touch for creating a new file or updating access and modification times of an existing file. + """ + filename = "/tmp/test_file.txt" + + self.operations.touch(filename) + + assert self.operations.isfile(filename) + + def test_isfile_true(self): + """ + Test isfile for an existing file. + """ + filename = "/etc/hosts" + + response = self.operations.isfile(filename) + + assert response is True + + def test_isfile_false(self): + """ + Test isfile for a non-existing file. 
+ """ + filename = "/nonexistent_file.txt" + + response = self.operations.isfile(filename) + + assert response is False From 8c373e63b131aa9ddc20b3dd6aa1c350a4c9e347 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Wed, 14 Jun 2023 23:32:31 +0200 Subject: [PATCH 12/23] PBCKP-152 add testgres tests for remote node --- tests/test_simple_remote.py | 1006 +++++++++++++++++++++++++++++++++++ 1 file changed, 1006 insertions(+) create mode 100755 tests/test_simple_remote.py diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py new file mode 100755 index 00000000..179f3ffb --- /dev/null +++ b/tests/test_simple_remote.py @@ -0,0 +1,1006 @@ +#!/usr/bin/env python +# coding: utf-8 + +import os +import re +import subprocess +import tempfile + + +import testgres +import time +import six +import unittest +import psutil + +import logging.config + +from contextlib import contextmanager +from shutil import rmtree + +from testgres.exceptions import \ + InitNodeException, \ + StartNodeException, \ + ExecUtilException, \ + BackupException, \ + QueryException, \ + TimeoutException, \ + TestgresException + +from testgres.config import \ + TestgresConfig, \ + configure_testgres, \ + scoped_config, \ + pop_config + +from testgres import \ + NodeStatus, \ + ProcessType, \ + IsolationLevel, \ + get_new_node + +from testgres import \ + get_bin_path, \ + get_pg_config, \ + get_pg_version + +from testgres import \ + First, \ + Any + +# NOTE: those are ugly imports +from testgres import bound_ports +from testgres.utils import PgVer +from testgres.node import ProcessProxy + + +def pg_version_ge(version): + cur_ver = PgVer(get_pg_version()) + min_ver = PgVer(version) + return cur_ver >= min_ver + + +def util_exists(util): + def good_properties(f): + return (os.path.exists(f) and # noqa: W504 + os.path.isfile(f) and # noqa: W504 + os.access(f, os.X_OK)) # yapf: disable + + # try to resolve it + if good_properties(get_bin_path(util)): + return True + + # check if util is in PATH + for path in os.environ["PATH"].split(os.pathsep): + if good_properties(os.path.join(path, util)): + return True + + +@contextmanager +def removing(f): + try: + yield f + finally: + if os.path.isfile(f): + os.remove(f) + elif os.path.isdir(f): + rmtree(f, ignore_errors=True) + + +def get_remote_node(): + return get_new_node(host='172.18.0.3', username='dev', ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519') + + +class TestgresRemoteTests(unittest.TestCase): + + def test_node_repr(self): + with get_remote_node() as node: + pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)" + self.assertIsNotNone(re.match(pattern, str(node))) + + def test_custom_init(self): + with get_remote_node() as node: + # enable page checksums + node.init(initdb_params=['-k']).start() + + with get_remote_node() as node: + node.init( + allow_streaming=True, + initdb_params=['--auth-local=reject', '--auth-host=reject']) + + hba_file = os.path.join(node.data_dir, 'pg_hba.conf') + with open(hba_file, 'r') as conf: + lines = conf.readlines() + + # check number of lines + self.assertGreaterEqual(len(lines), 6) + + # there should be no trust entries at all + self.assertFalse(any('trust' in s for s in lines)) + + def test_double_init(self): + with get_remote_node().init() as node: + # can't initialize node more than once + with self.assertRaises(InitNodeException): + node.init() + + def test_init_after_cleanup(self): + with get_remote_node() as node: + node.init().start().execute('select 1') + 
node.cleanup() + node.init().start().execute('select 1') + + @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing') + @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') + def test_init_unique_system_id(self): + # this function exists in PostgreSQL 9.6+ + query = 'select system_identifier from pg_control_system()' + + with scoped_config(cache_initdb=False): + with get_remote_node().init().start() as node0: + id0 = node0.execute(query)[0] + + with scoped_config(cache_initdb=True, + cached_initdb_unique=True) as config: + + self.assertTrue(config.cache_initdb) + self.assertTrue(config.cached_initdb_unique) + + # spawn two nodes; ids must be different + with get_remote_node().init().start() as node1, \ + get_remote_node().init().start() as node2: + + id1 = node1.execute(query)[0] + id2 = node2.execute(query)[0] + + # ids must increase + self.assertGreater(id1, id0) + self.assertGreater(id2, id1) + + def test_node_exit(self): + with self.assertRaises(QueryException): + with get_remote_node().init() as node: + base_dir = node.base_dir + node.safe_psql('select 1') + + # we should save the DB for "debugging" + self.assertTrue(os.path.exists(base_dir)) + rmtree(base_dir, ignore_errors=True) + + with get_remote_node().init() as node: + base_dir = node.base_dir + + # should have been removed by default + self.assertFalse(os.path.exists(base_dir)) + + def test_double_start(self): + with get_remote_node().init().start() as node: + # can't start node more than once + node.start() + self.assertTrue(node.is_started) + + def test_uninitialized_start(self): + with get_remote_node() as node: + # node is not initialized yet + with self.assertRaises(StartNodeException): + node.start() + + def test_restart(self): + with get_remote_node() as node: + node.init().start() + + # restart, ok + res = node.execute('select 1') + self.assertEqual(res, [(1, )]) + node.restart() + res = node.execute('select 2') + self.assertEqual(res, [(2, )]) + + # restart, fail + with self.assertRaises(StartNodeException): + node.append_conf('pg_hba.conf', 'DUMMY') + node.restart() + + def test_reload(self): + with get_remote_node() as node: + node.init().start() + + # change client_min_messages and save old value + cmm_old = node.execute('show client_min_messages') + node.append_conf(client_min_messages='DEBUG1') + + # reload config + node.reload() + + # check new value + cmm_new = node.execute('show client_min_messages') + self.assertEqual('debug1', cmm_new[0][0].lower()) + self.assertNotEqual(cmm_old, cmm_new) + + def test_pg_ctl(self): + with get_remote_node() as node: + node.init().start() + + status = node.pg_ctl(['status']) + self.assertTrue('PID' in status) + + def test_status(self): + self.assertTrue(NodeStatus.Running) + self.assertFalse(NodeStatus.Stopped) + self.assertFalse(NodeStatus.Uninitialized) + + # check statuses after each operation + with get_remote_node() as node: + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Uninitialized) + + node.init() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Stopped) + + node.start() + + self.assertNotEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Running) + + node.stop() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Stopped) + + node.cleanup() + + self.assertEqual(node.pid, 0) + self.assertEqual(node.status(), NodeStatus.Uninitialized) + + def test_psql(self): + with get_remote_node().init().start() as node: + + # check returned values (1 arg) + res = 
node.psql('select 1') + self.assertEqual(res, (0, b'1\n', b'')) + + # check returned values (2 args) + res = node.psql('postgres', 'select 2') + self.assertEqual(res, (0, b'2\n', b'')) + + # check returned values (named) + res = node.psql(query='select 3', dbname='postgres') + self.assertEqual(res, (0, b'3\n', b'')) + + # check returned values (1 arg) + res = node.safe_psql('select 4') + self.assertEqual(res, b'4\n') + + # check returned values (2 args) + res = node.safe_psql('postgres', 'select 5') + self.assertEqual(res, b'5\n') + + # check returned values (named) + res = node.safe_psql(query='select 6', dbname='postgres') + self.assertEqual(res, b'6\n') + + # check feeding input + node.safe_psql('create table horns (w int)') + node.safe_psql('copy horns from stdin (format csv)', + input=b"1\n2\n3\n\\.\n") + _sum = node.safe_psql('select sum(w) from horns') + self.assertEqual(b'6\n', _sum) + + # check psql's default args, fails + with self.assertRaises(QueryException): + node.psql() + + node.stop() + + # check psql on stopped node, fails + with self.assertRaises(QueryException): + node.safe_psql('select 1') + + def test_transactions(self): + with get_remote_node().init().start() as node: + + with node.connect() as con: + con.begin() + con.execute('create table test(val int)') + con.execute('insert into test values (1)') + con.commit() + + con.begin() + con.execute('insert into test values (2)') + res = con.execute('select * from test order by val asc') + self.assertListEqual(res, [(1, ), (2, )]) + con.rollback() + + con.begin() + res = con.execute('select * from test') + self.assertListEqual(res, [(1, )]) + con.rollback() + + con.begin() + con.execute('drop table test') + con.commit() + + def test_control_data(self): + with get_remote_node() as node: + + # node is not initialized yet + with self.assertRaises(ExecUtilException): + node.get_control_data() + + node.init() + data = node.get_control_data() + + # check returned dict + self.assertIsNotNone(data) + self.assertTrue(any('pg_control' in s for s in data.keys())) + + def test_backup_simple(self): + with get_remote_node() as master: + + # enable streaming for backups + master.init(allow_streaming=True) + + # node must be running + with self.assertRaises(BackupException): + master.backup() + + # it's time to start node + master.start() + + # fill node with some data + master.psql('create table test as select generate_series(1, 4) i') + + with master.backup(xlog_method='stream') as backup: + with backup.spawn_primary().start() as slave: + res = slave.execute('select * from test order by i asc') + self.assertListEqual(res, [(1, ), (2, ), (3, ), (4, )]) + + def test_backup_multiple(self): + with get_remote_node() as node: + node.init(allow_streaming=True).start() + + with node.backup(xlog_method='fetch') as backup1, \ + node.backup(xlog_method='fetch') as backup2: + + self.assertNotEqual(backup1.base_dir, backup2.base_dir) + + with node.backup(xlog_method='fetch') as backup: + with backup.spawn_primary('node1', destroy=False) as node1, \ + backup.spawn_primary('node2', destroy=False) as node2: + + self.assertNotEqual(node1.base_dir, node2.base_dir) + + def test_backup_exhaust(self): + with get_remote_node() as node: + node.init(allow_streaming=True).start() + + with node.backup(xlog_method='fetch') as backup: + + # exhaust backup by creating new node + with backup.spawn_primary(): + pass + + # now let's try to create one more node + with self.assertRaises(BackupException): + backup.spawn_primary() + + def 
test_backup_wrong_xlog_method(self):
+        with get_remote_node() as node:
+            node.init(allow_streaming=True).start()
+
+            with self.assertRaises(BackupException,
+                                   msg='Invalid xlog_method "wrong"'):
+                node.backup(xlog_method='wrong')
+
+    def test_pg_ctl_wait_option(self):
+        with get_remote_node() as node:
+            node.init().start(wait=False)
+            while True:
+                try:
+                    node.stop(wait=False)
+                    break
+                except ExecUtilException:
+                    # it's ok to get this exception here since node
+                    # could be not started yet
+                    pass
+
+    def test_replicate(self):
+        with get_remote_node() as node:
+            node.init(allow_streaming=True).start()
+
+            with node.replicate().start() as replica:
+                res = replica.execute('select 1')
+                self.assertListEqual(res, [(1, )])
+
+                node.execute('create table test (val int)', commit=True)
+
+                replica.catchup()
+
+                res = node.execute('select * from test')
+                self.assertListEqual(res, [])
+
+    @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+')
+    def test_synchronous_replication(self):
+        with get_remote_node() as master:
+            old_version = not pg_version_ge('9.6')
+
+            master.init(allow_streaming=True).start()
+
+            if not old_version:
+                master.append_conf('synchronous_commit = remote_apply')
+
+            # create standby
+            with master.replicate() as standby1, master.replicate() as standby2:
+                standby1.start()
+                standby2.start()
+
+                # check formatting
+                self.assertEqual(
+                    '1 ("{}", "{}")'.format(standby1.name, standby2.name),
+                    str(First(1, (standby1, standby2))))  # yapf: disable
+                self.assertEqual(
+                    'ANY 1 ("{}", "{}")'.format(standby1.name, standby2.name),
+                    str(Any(1, (standby1, standby2))))  # yapf: disable
+
+                # set synchronous_standby_names
+                master.set_synchronous_standbys(First(2, [standby1, standby2]))
+                master.restart()
+
+                # the following part of the test is only applicable to newer
+                # versions of PostgreSQL
+                if not old_version:
+                    master.safe_psql('create table abc(a int)')
+
+                    # Create a large transaction that will take some time to apply
+                    # on standby to check that it applies synchronously
+                    # (if synchronous_commit is set to 'on' or another lower level,
+                    # the standby most likely won't catch up fast enough and the test will fail)
+                    master.safe_psql(
+                        'insert into abc select generate_series(1, 1000000)')
+                    res = standby1.safe_psql('select count(*) from abc')
+                    self.assertEqual(res, b'1000000\n')
+
+    @unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
+    def test_logical_replication(self):
+        with get_remote_node() as node1, get_remote_node() as node2:
+            node1.init(allow_logical=True)
+            node1.start()
+            node2.init().start()
+
+            create_table = 'create table test (a int, b int)'
+            node1.safe_psql(create_table)
+            node2.safe_psql(create_table)
+
+            # create publication / create subscription
+            pub = node1.publish('mypub')
+            sub = node2.subscribe(pub, 'mysub')
+
+            node1.safe_psql('insert into test values (1, 1), (2, 2)')
+
+            # wait until changes apply on subscriber and check them
+            sub.catchup()
+            res = node2.execute('select * from test')
+            self.assertListEqual(res, [(1, 1), (2, 2)])
+
+            # disable and put some new data
+            sub.disable()
+            node1.safe_psql('insert into test values (3, 3)')
+
+            # enable and ensure that the data was successfully transferred
+            sub.enable()
+            sub.catchup()
+            res = node2.execute('select * from test')
+            self.assertListEqual(res, [(1, 1), (2, 2), (3, 3)])
+
+            # Add new tables. Since we added "all tables" to publication
+            # (default behaviour of publish() method) we don't need
+            # to explicitly perform pub.add_tables()
+            create_table = 'create table test2 (c char)'
+            node1.safe_psql(create_table)
+            node2.safe_psql(create_table)
+            sub.refresh()
+
+            # put new data
+            node1.safe_psql('insert into test2 values (\'a\'), (\'b\')')
+            sub.catchup()
+            res = node2.execute('select * from test2')
+            self.assertListEqual(res, [('a', ), ('b', )])
+
+            # drop subscription
+            sub.drop()
+            pub.drop()
+
+            # create new publication and subscription for specific table
+            # (omitting copying data as it's already done)
+            pub = node1.publish('newpub', tables=['test'])
+            sub = node2.subscribe(pub, 'newsub', copy_data=False)
+
+            node1.safe_psql('insert into test values (4, 4)')
+            sub.catchup()
+            res = node2.execute('select * from test')
+            self.assertListEqual(res, [(1, 1), (2, 2), (3, 3), (4, 4)])
+
+            # explicitly add table
+            with self.assertRaises(ValueError):
+                pub.add_tables([])  # fail
+            pub.add_tables(['test2'])
+            node1.safe_psql('insert into test2 values (\'c\')')
+            sub.catchup()
+            res = node2.execute('select * from test2')
+            self.assertListEqual(res, [('a', ), ('b', )])
+
+    @unittest.skipUnless(pg_version_ge('10'), 'requires 10+')
+    def test_logical_catchup(self):
+        """ Runs catchup 100 times to be sure that it is consistent """
+        with get_remote_node() as node1, get_remote_node() as node2:
+            node1.init(allow_logical=True)
+            node1.start()
+            node2.init().start()
+
+            create_table = 'create table test (key int primary key, val int); '
+            node1.safe_psql(create_table)
+            node1.safe_psql('alter table test replica identity default')
+            node2.safe_psql(create_table)
+
+            # create publication / create subscription
+            sub = node2.subscribe(node1.publish('mypub'), 'mysub')
+
+            for i in range(0, 100):
+                node1.execute('insert into test values ({0}, {0})'.format(i))
+                sub.catchup()
+                res = node2.execute('select * from test')
+                self.assertListEqual(res, [(
+                    i,
+                    i,
+                )])
+                node1.execute('delete from test')
+
+    @unittest.skipIf(pg_version_ge('10'), 'requires <10')
+    def test_logical_replication_fail(self):
+        with get_remote_node() as node:
+            with self.assertRaises(InitNodeException):
+                node.init(allow_logical=True)
+
+    def test_replication_slots(self):
+        with get_remote_node() as node:
+            node.init(allow_streaming=True).start()
+
+            with node.replicate(slot='slot1').start() as replica:
+                replica.execute('select 1')
+
+                # cannot create new slot with the same name
+                with self.assertRaises(TestgresException):
+                    node.replicate(slot='slot1')
+
+    def test_incorrect_catchup(self):
+        with get_remote_node() as node:
+            node.init(allow_streaming=True).start()
+
+            # node has no master, can't catch up
+            with self.assertRaises(TestgresException):
+                node.catchup()
+
+    def test_promotion(self):
+        with get_remote_node() as master:
+            master.init().start()
+            master.safe_psql('create table abc(id serial)')
+
+            with master.replicate().start() as replica:
+                master.stop()
+                replica.promote()
+
+                # check that the promoted standby became a writable master
+                replica.safe_psql('insert into abc values (1)')
+                res = replica.safe_psql('select * from abc')
+                self.assertEqual(res, b'1\n')
+
+    def test_dump(self):
+        query_create = 'create table test as select generate_series(1, 2) as val'
+        query_select = 'select * from test order by val asc'
+
+        with get_remote_node().init().start() as node1:
+
+            node1.execute(query_create)
+            for format in ['plain', 'custom', 'directory', 'tar']:
+                with removing(node1.dump(format=format)) as dump:
+                    with 
get_remote_node().init().start() as node3: + if format == 'directory': + self.assertTrue(os.path.isdir(dump)) + else: + self.assertTrue(os.path.isfile(dump)) + # restore dump + node3.restore(filename=dump) + res = node3.execute(query_select) + self.assertListEqual(res, [(1, ), (2, )]) + + def test_users(self): + with get_remote_node().init().start() as node: + node.psql('create role test_user login') + value = node.safe_psql('select 1', username='test_user') + self.assertEqual(b'1\n', value) + + def test_poll_query_until(self): + with get_remote_node() as node: + node.init().start() + + get_time = 'select extract(epoch from now())' + check_time = 'select extract(epoch from now()) - {} >= 5' + + start_time = node.execute(get_time)[0][0] + node.poll_query_until(query=check_time.format(start_time)) + end_time = node.execute(get_time)[0][0] + + self.assertTrue(end_time - start_time >= 5) + + # check 0 columns + with self.assertRaises(QueryException): + node.poll_query_until( + query='select from pg_catalog.pg_class limit 1') + + # check None, fail + with self.assertRaises(QueryException): + node.poll_query_until(query='create table abc (val int)') + + # check None, ok + node.poll_query_until(query='create table def()', + expected=None) # returns nothing + + # check 0 rows equivalent to expected=None + node.poll_query_until( + query='select * from pg_catalog.pg_class where true = false', + expected=None) + + # check arbitrary expected value, fail + with self.assertRaises(TimeoutException): + node.poll_query_until(query='select 3', + expected=1, + max_attempts=3, + sleep_time=0.01) + + # check arbitrary expected value, ok + node.poll_query_until(query='select 2', expected=2) + + # check timeout + with self.assertRaises(TimeoutException): + node.poll_query_until(query='select 1 > 2', + max_attempts=3, + sleep_time=0.01) + + # check ProgrammingError, fail + with self.assertRaises(testgres.ProgrammingError): + node.poll_query_until(query='dummy1') + + # check ProgrammingError, ok + with self.assertRaises(TimeoutException): + node.poll_query_until(query='dummy2', + max_attempts=3, + sleep_time=0.01, + suppress={testgres.ProgrammingError}) + + # check 1 arg, ok + node.poll_query_until('select true') + + def test_logging(self): + logfile = tempfile.NamedTemporaryFile('w', delete=True) + + log_conf = { + 'version': 1, + 'handlers': { + 'file': { + 'class': 'logging.FileHandler', + 'filename': logfile.name, + 'formatter': 'base_format', + 'level': logging.DEBUG, + }, + }, + 'formatters': { + 'base_format': { + 'format': '%(node)-5s: %(message)s', + }, + }, + 'root': { + 'handlers': ('file', ), + 'level': 'DEBUG', + }, + } + + logging.config.dictConfig(log_conf) + + with scoped_config(use_python_logging=True): + node_name = 'master' + + with get_new_node(name=node_name) as master: + master.init().start() + + # execute a dummy query a few times + for i in range(20): + master.execute('select 1') + time.sleep(0.01) + + # let logging worker do the job + time.sleep(0.1) + + # check that master's port is found + with open(logfile.name, 'r') as log: + lines = log.readlines() + self.assertTrue(any(node_name in s for s in lines)) + + # test logger after stop/start/restart + master.stop() + master.start() + master.restart() + self.assertTrue(master._logger.is_alive()) + + @unittest.skipUnless(util_exists('pgbench'), 'might be missing') + def test_pgbench(self): + with get_remote_node().init().start() as node: + + # initialize pgbench DB and run benchmarks + node.pgbench_init(scale=2, foreign_keys=True, + 
options=['-q']).pgbench_run(time=2) + + # run TPC-B benchmark + proc = node.pgbench(stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + options=['-T3']) + + out, _ = proc.communicate() + out = out.decode('utf-8') + + self.assertTrue('tps' in out) + + def test_pg_config(self): + # check same instances + a = get_pg_config() + b = get_pg_config() + self.assertEqual(id(a), id(b)) + + # save right before config change + c1 = get_pg_config() + + # modify setting for this scope + with scoped_config(cache_pg_config=False) as config: + + # sanity check for value + self.assertFalse(config.cache_pg_config) + + # save right after config change + c2 = get_pg_config() + + # check different instances after config change + self.assertNotEqual(id(c1), id(c2)) + + # check different instances + a = get_pg_config() + b = get_pg_config() + self.assertNotEqual(id(a), id(b)) + + def test_config_stack(self): + # no such option + with self.assertRaises(TypeError): + configure_testgres(dummy=True) + + # we have only 1 config in stack + with self.assertRaises(IndexError): + pop_config() + + d0 = TestgresConfig.cached_initdb_dir + d1 = 'dummy_abc' + d2 = 'dummy_def' + + with scoped_config(cached_initdb_dir=d1) as c1: + self.assertEqual(c1.cached_initdb_dir, d1) + + with scoped_config(cached_initdb_dir=d2) as c2: + + stack_size = len(testgres.config.config_stack) + + # try to break a stack + with self.assertRaises(TypeError): + with scoped_config(dummy=True): + pass + + self.assertEqual(c2.cached_initdb_dir, d2) + self.assertEqual(len(testgres.config.config_stack), stack_size) + + self.assertEqual(c1.cached_initdb_dir, d1) + + self.assertEqual(TestgresConfig.cached_initdb_dir, d0) + + def test_unix_sockets(self): + with get_remote_node() as node: + node.init(unix_sockets=False, allow_streaming=True) + node.start() + + node.execute('select 1') + node.safe_psql('select 1') + + with node.replicate().start() as r: + r.execute('select 1') + r.safe_psql('select 1') + + def test_auto_name(self): + with get_remote_node().init(allow_streaming=True).start() as m: + with m.replicate().start() as r: + + # check that nodes are running + self.assertTrue(m.status()) + self.assertTrue(r.status()) + + # check their names + self.assertNotEqual(m.name, r.name) + self.assertTrue('testgres' in m.name) + self.assertTrue('testgres' in r.name) + + def test_file_tail(self): + from testgres.utils import file_tail + + s1 = "the quick brown fox jumped over that lazy dog\n" + s2 = "abc\n" + s3 = "def\n" + + with tempfile.NamedTemporaryFile(mode='r+', delete=True) as f: + sz = 0 + while sz < 3 * 8192: + sz += len(s1) + f.write(s1) + f.write(s2) + f.write(s3) + + f.seek(0) + lines = file_tail(f, 3) + self.assertEqual(lines[0], s1) + self.assertEqual(lines[1], s2) + self.assertEqual(lines[2], s3) + + f.seek(0) + lines = file_tail(f, 1) + self.assertEqual(lines[0], s3) + + def test_isolation_levels(self): + with get_remote_node().init().start() as node: + with node.connect() as con: + # string levels + con.begin('Read Uncommitted').commit() + con.begin('Read Committed').commit() + con.begin('Repeatable Read').commit() + con.begin('Serializable').commit() + + # enum levels + con.begin(IsolationLevel.ReadUncommitted).commit() + con.begin(IsolationLevel.ReadCommitted).commit() + con.begin(IsolationLevel.RepeatableRead).commit() + con.begin(IsolationLevel.Serializable).commit() + + # check wrong level + with self.assertRaises(QueryException): + con.begin('Garbage').commit() + + def test_ports_management(self): + # check that no ports have been bound 
yet
+        self.assertEqual(len(bound_ports), 0)
+
+        with get_remote_node() as node:
+            # check that we've just bound a port
+            self.assertEqual(len(bound_ports), 1)
+
+            # check that bound_ports contains our port
+            port_1 = list(bound_ports)[0]
+            port_2 = node.port
+            self.assertEqual(port_1, port_2)
+
+        # check that port has been freed successfully
+        self.assertEqual(len(bound_ports), 0)
+
+    def test_exceptions(self):
+        str(StartNodeException('msg', [('file', 'lines')]))
+        str(ExecUtilException('msg', 'cmd', 1, 'out'))
+        str(QueryException('msg', 'query'))
+
+    def test_version_management(self):
+        a = PgVer('10.0')
+        b = PgVer('10')
+        c = PgVer('9.6.5')
+        d = PgVer('15.0')
+        e = PgVer('15rc1')
+        f = PgVer('15beta4')
+
+        self.assertTrue(a == b)
+        self.assertTrue(b > c)
+        self.assertTrue(a > c)
+        self.assertTrue(d > e)
+        self.assertTrue(e > f)
+        self.assertTrue(d > f)
+
+        version = get_pg_version()
+        with get_remote_node() as node:
+            self.assertTrue(isinstance(version, six.string_types))
+            self.assertTrue(isinstance(node.version, PgVer))
+            self.assertEqual(node.version, PgVer(version))
+
+    def test_child_pids(self):
+        master_processes = [
+            ProcessType.AutovacuumLauncher,
+            ProcessType.BackgroundWriter,
+            ProcessType.Checkpointer,
+            ProcessType.StatsCollector,
+            ProcessType.WalSender,
+            ProcessType.WalWriter,
+        ]
+
+        if pg_version_ge('10'):
+            master_processes.append(ProcessType.LogicalReplicationLauncher)
+
+        repl_processes = [
+            ProcessType.Startup,
+            ProcessType.WalReceiver,
+        ]
+
+        with get_remote_node().init().start() as master:
+
+            # master node doesn't have a source walsender!
+            with self.assertRaises(TestgresException):
+                master.source_walsender
+
+            with master.connect() as con:
+                self.assertGreater(con.pid, 0)
+
+            with master.replicate().start() as replica:
+
+                # test __str__ method
+                str(master.child_processes[0])
+
+                master_pids = master.auxiliary_pids
+                for ptype in master_processes:
+                    self.assertIn(ptype, master_pids)
+
+                replica_pids = replica.auxiliary_pids
+                for ptype in repl_processes:
+                    self.assertIn(ptype, replica_pids)
+
+                # there should be exactly 1 source walsender for replica
+                self.assertEqual(len(master_pids[ProcessType.WalSender]), 1)
+                pid1 = master_pids[ProcessType.WalSender][0]
+                pid2 = replica.source_walsender.pid
+                self.assertEqual(pid1, pid2)
+
+                replica.stop()
+
+                # there should be no walsender after we've stopped replica
+                with self.assertRaises(TestgresException):
+                    replica.source_walsender
+
+    def test_child_process_dies(self):
+        # test for FileNotFound exception during child_processes() function
+        with subprocess.Popen(["sleep", "60"]) as process:
+            self.assertEqual(process.poll(), None)
+            # collect list of processes currently running
+            children = psutil.Process(os.getpid()).children()
+            # kill a process, so received children dictionary becomes invalid
+            process.kill()
+            process.wait()
+            # try to handle children list -- missing processes will have ptype "ProcessType.Unknown"
+            [ProcessProxy(p) for p in children]
+
+
+if __name__ == '__main__':
+    if os.environ.get('ALT_CONFIG'):
+        suite = unittest.TestSuite()
+
+        # Small subset of tests for alternative configs (PG_BIN or PG_CONFIG)
+        suite.addTest(TestgresRemoteTests('test_pg_config'))
+        suite.addTest(TestgresRemoteTests('test_pg_ctl'))
+        suite.addTest(TestgresRemoteTests('test_psql'))
+        suite.addTest(TestgresRemoteTests('test_replicate'))
+
+        print('Running tests for alternative config:')
+        for t in suite:
+            print(t)
+        print()
+
+        runner = unittest.TextTestRunner()
+        runner.run(suite)
+    else:
+        unittest.main()

From 
72e6d5d466bb76473b376c9bc4b9f36fa4afbda0 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Sat, 17 Jun 2023 02:08:23 +0200 Subject: [PATCH 13/23] PBCKP-152 fixed test_simple and test_remote --- setup.py | 3 +- testgres/backup.py | 4 +- testgres/cache.py | 7 +- testgres/config.py | 6 + testgres/connection.py | 2 +- testgres/defaults.py | 10 +- testgres/logger.py | 23 ++-- testgres/node.py | 97 ++++++++------- testgres/operations/local_ops.py | 101 +++++++++++----- testgres/operations/os_ops.py | 2 +- testgres/operations/remote_ops.py | 195 +++++++++++++++++++----------- testgres/utils.py | 109 +++-------------- tests/test_remote.py | 23 ++-- tests/test_simple.py | 14 +-- tests/test_simple_remote.py | 61 +++++----- 15 files changed, 353 insertions(+), 304 deletions(-) diff --git a/setup.py b/setup.py index b5162dce..8cb0f70a 100755 --- a/setup.py +++ b/setup.py @@ -13,7 +13,8 @@ "psutil", "packaging", "paramiko", - "fabric" + "fabric", + "sshtunnel" ] # Add compatibility enum class diff --git a/testgres/backup.py b/testgres/backup.py index c0fd6e50..c4cc952b 100644 --- a/testgres/backup.py +++ b/testgres/backup.py @@ -77,7 +77,7 @@ def __init__(self, "-D", data_dir, "-X", xlog_method.value ] # yapf: disable - execute_utility(_params, self.log_file, self.os_ops) + execute_utility(_params, self.log_file) def __enter__(self): return self @@ -139,7 +139,7 @@ def spawn_primary(self, name=None, destroy=True): # Build a new PostgresNode NodeClass = self.original_node.__class__ - with clean_on_error(NodeClass(name=name, base_dir=base_dir)) as node: + with clean_on_error(NodeClass(name=name, base_dir=base_dir, os_ops=self.original_node.os_ops)) as node: # New nodes should always remove dir tree node._should_rm_dirs = True diff --git a/testgres/cache.py b/testgres/cache.py index ef07e976..bf8658c9 100644 --- a/testgres/cache.py +++ b/testgres/cache.py @@ -26,12 +26,11 @@ def cached_initdb(data_dir, logfile=None, params=None, os_ops: OsOperations = Lo """ Perform initdb or use cached node files. 
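
    Example (an illustrative sketch only; the paths are placeholders):

        cached_initdb('/tmp/testgres-data', logfile='/tmp/initdb.log')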
""" - testgres_config.os_ops = os_ops - def call_initdb(initdb_dir, log=None): + def call_initdb(initdb_dir, log=logfile): try: _params = [get_bin_path("initdb"), "-D", initdb_dir, "-N"] - execute_utility(_params + (params or []), log, os_ops) + execute_utility(_params + (params or []), log) except ExecUtilException as e: raise_from(InitNodeException("Failed to run initdb"), e) @@ -62,7 +61,7 @@ def call_initdb(initdb_dir, log=None): # XXX: build new WAL segment with our system id _params = [get_bin_path("pg_resetwal"), "-D", data_dir, "-f"] - execute_utility(_params, logfile, os_ops=os_ops) + execute_utility(_params, logfile) except ExecUtilException as e: msg = "Failed to reset WAL for system id" diff --git a/testgres/config.py b/testgres/config.py index b21d8356..b6c43926 100644 --- a/testgres/config.py +++ b/testgres/config.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from .consts import TMP_CACHE +from .operations.os_ops import OsOperations from .operations.local_ops import LocalOperations @@ -121,6 +122,11 @@ def copy(self): return copy.copy(self) + @staticmethod + def set_os_ops(os_ops: OsOperations): + testgres_config.os_ops = os_ops + testgres_config.cached_initdb_dir = os_ops.mkdtemp(prefix=TMP_CACHE) + # cached dirs to be removed cached_initdb_dirs = set() diff --git a/testgres/connection.py b/testgres/connection.py index 6725b14f..d28d81bd 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -37,7 +37,7 @@ def __init__(self, # Set default arguments dbname = dbname or default_dbname() - username = username or default_username(node.os_ops) + username = username or default_username() self._node = node diff --git a/testgres/defaults.py b/testgres/defaults.py index 34bcc08b..d77361d7 100644 --- a/testgres/defaults.py +++ b/testgres/defaults.py @@ -2,7 +2,7 @@ import struct import uuid -from .operations.local_ops import LocalOperations +from .config import testgres_config as tconf def default_dbname(): @@ -13,11 +13,11 @@ def default_dbname(): return 'postgres' -def default_username(os_ops=LocalOperations()): +def default_username(): """ Return default username (current user). """ - return os_ops.get_user() + return tconf.os_ops.get_user() def generate_app_name(): @@ -28,7 +28,7 @@ def generate_app_name(): return 'testgres-{}'.format(str(uuid.uuid4())) -def generate_system_id(os_ops=LocalOperations()): +def generate_system_id(): """ Generate a new 64-bit unique system identifier for node. 
""" @@ -43,7 +43,7 @@ def generate_system_id(os_ops=LocalOperations()): system_id = 0 system_id |= (secs << 32) system_id |= (usecs << 12) - system_id |= (os_ops.get_pid() & 0xFFF) + system_id |= (tconf.os_ops.get_pid() & 0xFFF) # pack ULL in native byte order return struct.pack('=Q', system_id) diff --git a/testgres/logger.py b/testgres/logger.py index abd4d255..59579002 100644 --- a/testgres/logger.py +++ b/testgres/logger.py @@ -5,19 +5,20 @@ import threading import time - # create logger log = logging.getLogger('Testgres') -log.setLevel(logging.DEBUG) -# create console handler and set level to debug -ch = logging.StreamHandler() -ch.setLevel(logging.DEBUG) -# create formatter -formatter = logging.Formatter('\n%(asctime)s - %(name)s[%(levelname)s]: %(message)s') -# add formatter to ch -ch.setFormatter(formatter) -# add ch to logger -log.addHandler(ch) + +if not log.handlers: + log.setLevel(logging.WARN) + # create console handler and set level to debug + ch = logging.StreamHandler() + ch.setLevel(logging.WARN) + # create formatter + formatter = logging.Formatter('\n%(asctime)s - %(name)s[%(levelname)s]: %(message)s') + # add formatter to ch + ch.setFormatter(formatter) + # add ch to logger + log.addHandler(ch) class TestgresLogger(threading.Thread): diff --git a/testgres/node.py b/testgres/node.py index 9d183f96..5ad18ace 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -6,7 +6,6 @@ import threading from queue import Queue -import psutil import time try: @@ -23,7 +22,6 @@ except ImportError: raise ImportError("You must have psycopg2 or pg8000 modules installed") -from shutil import rmtree from six import raise_from, iteritems, text_type from .enums import \ @@ -128,7 +126,7 @@ def __repr__(self): class PostgresNode(object): def __init__(self, name=None, port=None, base_dir=None, - host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username()): + host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username(), os_ops=None): """ PostgresNode constructor. 
@@ -147,15 +145,19 @@ def __init__(self, name=None, port=None, base_dir=None, # basic self.name = name or generate_app_name() - self.port = port or reserve_port() - self.host = host - self.hostname = hostname - self.ssh_key = ssh_key - if hostname != 'localhost' or host != '127.0.0.1': - self.os_ops = RemoteOperations(host, hostname, ssh_key) + if os_ops: + self.os_ops = os_ops + elif ssh_key: + self.os_ops = RemoteOperations(host=host, hostname=hostname, ssh_key=ssh_key, username=username) else: - self.os_ops = LocalOperations(username=username) + self.os_ops = LocalOperations(host=host, hostname=hostname, username=username) + + self.port = self.os_ops.port or reserve_port() + + self.host = self.os_ops.host + self.hostname = self.os_ops.hostname + self.ssh_key = self.os_ops.ssh_key testgres_config.os_ops = self.os_ops # defaults for __exit__() @@ -243,7 +245,7 @@ def child_processes(self): """ # get a list of postmaster's children - children = psutil.Process(self.pid).children() + children = self.os_ops.get_remote_children(self.pid) return [ProcessProxy(p) for p in children] @@ -511,21 +513,18 @@ def get_auth_method(t): # get auth methods auth_local = get_auth_method('local') auth_host = get_auth_method('host') + subnet_base = ".".join(self.os_ops.host.split('.')[:-1] + ['0']) new_lines = [ u"local\treplication\tall\t\t\t{}\n".format(auth_local), u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), - - u"host\treplication\tall\t0.0.0.0/0\t{}\n".format(auth_host), - u"host\tall\tall\t0.0.0.0/0\t{}\n".format(auth_host), - - u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host) + u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host), + u"host\treplication\t{}\t{}/24\t\t{}\n".format(self.os_ops.username, subnet_base, auth_host), + u"host\tall\t{}\t{}/24\t\t{}\n".format(self.os_ops.username, subnet_base, auth_host) ] # yapf: disable # write missing lines - for line in new_lines: - if line not in lines: - self.os_ops.write(hba_conf, line) + self.os_ops.write(hba_conf, new_lines) # overwrite config file self.os_ops.write(postgres_conf, '', truncate=True) @@ -533,7 +532,7 @@ def get_auth_method(t): self.append_conf(fsync=fsync, max_worker_processes=MAX_WORKER_PROCESSES, log_statement=log_statement, - listen_addresses=self.host, + listen_addresses='*', port=self.port) # yapf:disable # common replication settings @@ -598,9 +597,11 @@ def append_conf(self, line='', filename=PG_CONF_FILE, **kwargs): value = 'on' if value else 'off' elif not str(value).replace('.', '', 1).isdigit(): value = "'{}'".format(value) - - # format a new config line - lines.append('{} = {}'.format(option, value)) + if value == '*': + lines.append("{} = '*'".format(option)) + else: + # format a new config line + lines.append('{} = {}'.format(option, value)) config_name = os.path.join(self.data_dir, filename) conf_text = '' @@ -624,7 +625,9 @@ def status(self): "-D", self.data_dir, "status" ] # yapf: disable - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + out = execute_utility(_params, self.utils_log_file) + if 'no server running' in out: + return NodeStatus.Uninitialized return NodeStatus.Running except ExecUtilException as e: @@ -646,7 +649,7 @@ def get_control_data(self): _params += ["-D"] if self._pg_version >= PgVer('9.5') else [] _params += [self.data_dir] - data = execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + data = execute_utility(_params, self.utils_log_file) out_dict = {} @@ -709,8 +712,8 @@ def start(self, params=[], wait=True): ] + params # yapf: 
disable try: - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) - except ExecUtilException as e: + execute_utility(_params, self.utils_log_file) + except Exception as e: msg = 'Cannot start node' files = self._collect_special_files() raise_from(StartNodeException(msg, files), e) @@ -740,7 +743,7 @@ def stop(self, params=[], wait=True): "stop" ] + params # yapf: disable - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_file) self._maybe_stop_logger() self.is_started = False @@ -782,7 +785,7 @@ def restart(self, params=[]): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_file) except ExecUtilException as e: msg = 'Cannot restart node' files = self._collect_special_files() @@ -809,7 +812,7 @@ def reload(self, params=[]): "reload" ] + params # yapf: disable - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_file) return self @@ -831,7 +834,7 @@ def promote(self, dbname=None, username=None): "promote" ] # yapf: disable - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_file) # for versions below 10 `promote` is asynchronous so we need to wait # until it actually becomes writable @@ -866,7 +869,7 @@ def pg_ctl(self, params): "-w" # wait ] + params # yapf: disable - return execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + return execute_utility(_params, self.utils_log_file) def free_port(self): """ @@ -898,7 +901,7 @@ def cleanup(self, max_attempts=3): else: rm_dir = self.data_dir # just data, save logs - rmtree(rm_dir, ignore_errors=True) + self.os_ops.rmdirs(rm_dir, ignore_errors=True) return self @@ -951,7 +954,7 @@ def psql(self, # select query source if query: - psql_params.extend(("-c", query)) + psql_params.extend(("-c", '"{}"'.format(query))) elif filename: psql_params.extend(("-f", filename)) else: @@ -961,7 +964,7 @@ def psql(self, psql_params.append(dbname) # start psql process - status_code, out, err = self.os_ops.exec_command(psql_params, shell=False, verbose=True, input=input) + status_code, out, err = self.os_ops.exec_command(psql_params, verbose=True, input=input) return status_code, out, err @@ -987,13 +990,17 @@ def safe_psql(self, query=None, expect_error=False, **kwargs): # force this setting kwargs['ON_ERROR_STOP'] = 1 - - ret, out, err = self.psql(query=query, **kwargs) + try: + ret, out, err = self.psql(query=query, **kwargs) + except ExecUtilException as e: + ret = e.exit_code + out = e.out + err = e.message if ret: if expect_error: - out = (err or b'').decode('utf-8') + out = err or b'' else: - raise QueryException((err or b'').decode('utf-8'), query) + raise QueryException(err or b'', query) elif expect_error: assert False, f"Exception was expected, but query finished successfully: `{query}` " @@ -1049,7 +1056,7 @@ def tmpfile(): "-F", format.value ] # yapf: disable - execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_file) return filename @@ -1078,7 +1085,7 @@ def restore(self, filename, dbname=None, username=None): # try pg_restore if dump is binary formate, and psql if not try: - execute_utility(_params, self.utils_log_name, os_ops=self.os_ops) + execute_utility(_params, self.utils_log_name) except ExecUtilException: self.psql(filename=filename, dbname=dbname, username=username) @@ -1335,7 +1342,7 @@ def 
pgbench(self, # Set default arguments dbname = dbname or default_dbname() - username = username or default_username(self.os_ops) + username = username or default_username() _params = [ get_bin_path("pgbench"), @@ -1347,7 +1354,7 @@ def pgbench(self, # should be the last one _params.append(dbname) - proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, shell=False, proc=True) + proc = self.os_ops.exec_command(_params, stdout=stdout, stderr=stderr, wait_exit=True, proc=True) return proc @@ -1387,7 +1394,7 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): # Set default arguments dbname = dbname or default_dbname() - username = username or default_username(os_ops=self.os_ops) + username = username or default_username() _params = [ get_bin_path("pgbench"), @@ -1410,7 +1417,7 @@ def pgbench_run(self, dbname=None, username=None, options=[], **kwargs): # should be the last one _params.append(dbname) - return execute_utility(_params, self.utils_log_file, os_ops=self.os_ops) + return execute_utility(_params, self.utils_log_file) def connect(self, dbname=None, diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index acb10df8..010e3cc0 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -5,32 +5,66 @@ import tempfile from shutil import rmtree +import psutil + +from testgres.exceptions import ExecUtilException from testgres.logger import log from .os_ops import OsOperations from .os_ops import pglib +try: + from shutil import which as find_executable +except ImportError: + from distutils.spawn import find_executable + CMD_TIMEOUT_SEC = 60 class LocalOperations(OsOperations): - def __init__(self, username=None): - super().__init__() + def __init__(self, host='127.0.0.1', hostname='localhost', port=None, username=None): + super().__init__(username) + self.host = host + self.hostname = hostname + self.port = port + self.ssh_key = None self.username = username or self.get_user() # Command execution def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): - log.debug(f"os_ops.exec_command: `{cmd}`; remote={self.remote}") - # Source global profile file + execute command - try: + """ + Execute a command in a subprocess. + + Args: + - cmd: The command to execute. + - wait_exit: Whether to wait for the subprocess to exit before returning. + - verbose: Whether to return verbose output. + - expect_error: Whether to raise an error if the subprocess exits with an error status. + - encoding: The encoding to use for decoding the subprocess output. + - shell: Whether to use shell when executing the subprocess. + - text: Whether to return str instead of bytes for the subprocess output. + - input: The input to pass to the subprocess. + - stdout: The stdout to use for the subprocess. + - stderr: The stderr to use for the subprocess. + - proc: The process to use for subprocess creation. + :return: The output of the subprocess. 
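+
+        Example (sketch; assumes a LocalOperations instance and a trivial command):
+
+            ops = LocalOperations()
+            status, out, err = ops.exec_command('echo hello', verbose=True, encoding='utf-8')
+            # status == 0, out == 'hello\n', err == ''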
+ """ + if isinstance(cmd, list): + cmd = " ".join(cmd) + log.debug(f"Executing command: `{cmd}`") + + if os.name == 'nt': + with tempfile.NamedTemporaryFile() as buf: + process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT) + process.communicate() + buf.seek(0) + result = buf.read().decode(encoding) + return result + else: if proc: - return subprocess.Popen(cmd, - shell=shell, - stdin=input or subprocess.PIPE, - stdout=stdout, - stderr=stderr) + return subprocess.Popen(cmd, shell=shell, stdin=input, stdout=stdout, stderr=stderr) process = subprocess.run( cmd, input=input, @@ -43,41 +77,32 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, exit_status = process.returncode result = process.stdout error = process.stderr + found_error = "error" in error.decode(encoding or 'utf-8').lower() + if encoding: + result = result.decode(encoding) + error = error.decode(encoding) if expect_error: raise Exception(result, error) - if exit_status != 0 or "error" in error.lower().decode(encoding or 'utf-8'): # Decode error for comparison - log.error( - f"Problem in executing command: `{cmd}`\nerror: {error.decode(encoding or 'utf-8')}\nexit_code: {exit_status}" - # Decode for logging - ) - + if exit_status != 0 or found_error: + if exit_status == 0: + exit_status = 1 + raise ExecUtilException(message=f'Utility exited with non-zero code. Error `{error}`', + command=cmd, + exit_code=exit_status, + out=result) if verbose: return exit_status, result, error else: return result - except Exception as e: - log.error(f"Unexpected error while executing command `{cmd}`: {e}") - return None - # Environment setup def environ(self, var_name): cmd = f"echo ${var_name}" - return self.exec_command(cmd).strip() + return self.exec_command(cmd, encoding='utf-8').strip() def find_executable(self, executable): - search_paths = self.environ("PATH") - if not search_paths: - return None - - search_paths = search_paths.split(self.pathsep) - for path in search_paths: - remote_file = os.path.join(path, executable) - if self.isfile(remote_file): - return remote_file - - return None + return find_executable(executable) def is_executable(self, file): # Check if the file is executable @@ -157,6 +182,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal read_and_write: If True, the file will be opened with read and write permissions ('r+' option); if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option) """ + # If it is a bytes str or list + if isinstance(data, bytes) or isinstance(data, list) and all(isinstance(item, bytes) for item in data): + binary = True mode = "wb" if binary else "w" if not truncate: mode = "ab" if binary else "a" @@ -221,6 +249,12 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None): def isfile(self, remote_file): return os.path.isfile(remote_file) + def isdir(self, dirname): + return os.path.isdir(dirname) + + def remove_file(self, filename): + return os.remove(filename) + # Processes control def kill(self, pid, signal): # Kill the process @@ -231,6 +265,9 @@ def get_pid(self): # Get current process id return os.getpid() + def get_remote_children(self, pid): + return psutil.Process(pid).children() + # Database control def db_connect(self, dbname, user, password=None, host="localhost", port=5432): conn = pglib.connect( diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index 1ee1f869..68925616 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -15,7 
+15,7 @@ def __init__(self, username=None): self.username = username # Command execution - def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False): + def exec_command(self, cmd, **kwargs): raise NotImplementedError() # Environment setup diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index e2248015..d45614a1 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -2,20 +2,30 @@ import tempfile from typing import Optional +import sshtunnel + import paramiko from paramiko import SSHClient -from logger import log +from testgres.exceptions import ExecUtilException +from testgres.logger import log + from .os_ops import OsOperations from .os_ops import pglib +sshtunnel.SSH_TIMEOUT = 5.0 +sshtunnel.TUNNEL_TIMEOUT = 5.0 + + error_markers = [b'error', b'Permission denied'] class RemoteOperations(OsOperations): - def __init__(self, hostname="localhost", host="127.0.0.1", ssh_key=None, username=None): + def __init__(self, host="127.0.0.1", hostname='localhost', port=None, ssh_key=None, username=None): super().__init__(username) self.host = host + self.hostname = hostname + self.port = port self.ssh_key = ssh_key self.remote = True self.ssh = self.ssh_connect() @@ -57,51 +67,50 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa Args: - cmd (str): The command to be executed. """ + if self.ssh is None or not self.ssh.get_transport() or not self.ssh.get_transport().is_active(): + self.ssh = self.ssh_connect() + if isinstance(cmd, list): cmd = " ".join(cmd) - try: - if input: - stdin, stdout, stderr = self.ssh.exec_command(cmd) - stdin.write(input.encode("utf-8")) - stdin.flush() - else: - stdin, stdout, stderr = self.ssh.exec_command(cmd) - exit_status = 0 - if wait_exit: - exit_status = stdout.channel.recv_exit_status() - - if encoding: - result = stdout.read().decode(encoding) - error = stderr.read().decode(encoding) - else: - result = stdout.read() - error = stderr.read() + if input: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + stdin.write(input) + stdin.flush() + else: + stdin, stdout, stderr = self.ssh.exec_command(cmd) + exit_status = 0 + if wait_exit: + exit_status = stdout.channel.recv_exit_status() + + if encoding: + result = stdout.read().decode(encoding) + error = stderr.read().decode(encoding) + else: + result = stdout.read() + error = stderr.read() - if expect_error: - raise Exception(result, error) + if expect_error: + raise Exception(result, error) - if encoding: - error_found = exit_status != 0 or any( - marker.decode(encoding) in error for marker in error_markers) - else: - error_found = exit_status != 0 or any( - marker in error for marker in error_markers) - - if error_found: - log.error( - f"Problem in executing command: `{cmd}`\nerror: {error}\nexit_code: {exit_status}" - ) - if exit_status == 0: - exit_status = 1 - - if verbose: - return exit_status, result, error - else: - return result + if encoding: + error_found = exit_status != 0 or any( + marker.decode(encoding) in error for marker in error_markers) + else: + error_found = exit_status != 0 or any( + marker in error for marker in error_markers) - except Exception as e: - log.error(f"Unexpected error while executing command `{cmd}`: {e}") - return None + if error_found: + if exit_status == 0: + exit_status = 1 + raise ExecUtilException(message=f"Utility exited with non-zero code. 
Error: {error.decode(encoding or 'utf-8')}", + command=cmd, + exit_code=exit_status, + out=result) + + if verbose: + return exit_status, result, error + else: + return result # Environment setup def environ(self, var_name: str) -> str: @@ -111,7 +120,7 @@ def environ(self, var_name: str) -> str: - var_name (str): The name of the environment variable. """ cmd = f"echo ${var_name}" - return self.exec_command(cmd).strip() + return self.exec_command(cmd, encoding='utf-8').strip() def find_executable(self, executable): search_paths = self.environ("PATH") @@ -142,7 +151,7 @@ def add_to_path(self, new_path): os.environ["PATH"] = f"{new_path}{pathsep}{path}" return pathsep - def set_env(self, var_name: str, var_val: str) -> None: + def set_env(self, var_name: str, var_val: str): """ Set the value of an environment variable. Args: @@ -153,11 +162,11 @@ def set_env(self, var_name: str, var_val: str) -> None: # Get environment variables def get_user(self): - return self.exec_command("echo $USER") + return self.exec_command("echo $USER", encoding='utf-8').strip() def get_name(self): cmd = 'python3 -c "import os; print(os.name)"' - return self.exec_command(cmd).strip() + return self.exec_command(cmd, encoding='utf-8').strip() # Work with dirs def makedirs(self, path, remove_existing=False): @@ -219,10 +228,19 @@ def mkdtemp(self, prefix=None): """ Creates a temporary directory in the remote server. Args: - prefix (str): The prefix of the temporary directory name. + - prefix (str): The prefix of the temporary directory name. """ - temp_dir = self.exec_command(f"mkdtemp -d {prefix}", encoding='utf-8') - return temp_dir.strip() + if prefix: + temp_dir = self.exec_command(f"mktemp -d {prefix}XXXXX", encoding='utf-8') + else: + temp_dir = self.exec_command("mktemp -d", encoding='utf-8') + + if temp_dir: + if not os.path.isabs(temp_dir): + temp_dir = os.path.join('/home', self.username, temp_dir.strip()) + return temp_dir + else: + raise ExecUtilException("Could not create temporary directory.") def mkstemp(self, prefix=None): cmd = f"mktemp {prefix}XXXXXX" @@ -230,6 +248,10 @@ def mkstemp(self, prefix=None): return filename def copytree(self, src, dst): + if not os.path.isabs(dst): + dst = os.path.join('~', dst) + if self.isdir(dst): + raise FileExistsError(f"Directory {dst} already exists.") return self.exec_command(f"cp -r {src} {dst}") # Work with files @@ -253,20 +275,40 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal if read_and_write: mode = "r+b" if binary else "r+" - with tempfile.NamedTemporaryFile(mode=mode) as tmp_file: + with tempfile.NamedTemporaryFile(mode=mode, delete=False) as tmp_file: + if not truncate: + with self.ssh_connect() as ssh: + sftp = ssh.open_sftp() + try: + sftp.get(filename, tmp_file.name) + tmp_file.seek(0, os.SEEK_END) + except FileNotFoundError: + pass # File does not exist yet, we'll create it + sftp.close() if isinstance(data, bytes) and not binary: data = data.decode(encoding) elif isinstance(data, str) and binary: data = data.encode(encoding) - - tmp_file.write(data) + if isinstance(data, list): + # ensure each line ends with a newline + data = [s if s.endswith('\n') else s + '\n' for s in data] + tmp_file.writelines(data) + else: + tmp_file.write(data) tmp_file.flush() with self.ssh_connect() as ssh: sftp = ssh.open_sftp() + remote_directory = os.path.dirname(filename) + try: + sftp.stat(remote_directory) + except IOError: + sftp.mkdir(remote_directory) sftp.put(tmp_file.name, filename) sftp.close() + os.remove(tmp_file.name) + def 
touch(self, filename): """ Create a new file or update the access and modification times of an existing file on the remote server. @@ -307,6 +349,15 @@ def isfile(self, remote_file): result = int(stdout.strip()) return result == 0 + def isdir(self, dirname): + cmd = f"if [ -d {dirname} ]; then echo True; else echo False; fi" + response = self.exec_command(cmd, encoding='utf-8') + return response.strip() == "True" + + def remove_file(self, filename): + cmd = f"rm {filename}" + return self.exec_command(cmd) + # Processes control def kill(self, pid, signal): # Kill the process @@ -317,8 +368,14 @@ def get_pid(self): # Get current process id return self.exec_command("echo $$") + def get_remote_children(self, pid): + command = f"pgrep -P {pid}" + stdin, stdout, stderr = self.ssh.exec_command(command) + children = stdout.readlines() + return [int(child_pid.strip()) for child_pid in children] + # Database control - def db_connect(self, dbname, user, password=None, host="127.0.0.1", hostname="localhost", port=5432): + def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432): """ Connects to a PostgreSQL database on the remote system. Args: @@ -332,21 +389,19 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", hostname="lo This function establishes a connection to a PostgreSQL database on the remote system using the specified parameters. It returns a connection object that can be used to interact with the database. """ - transport = self.ssh.get_transport() - local_port = 9090 # or any other available port - - transport.open_channel( - 'direct-tcpip', - (hostname, port), - (host, local_port) - ) - - conn = pglib.connect( - host=host, - port=local_port, - database=dbname, - user=user, - password=password, - ) - return conn + with sshtunnel.open_tunnel( + (host, 22), # Remote server IP and SSH port + ssh_username=self.username, + ssh_pkey=self.ssh_key, + remote_bind_address=(host, port), # PostgreSQL server IP and PostgreSQL port + local_bind_address=('localhost', port), # Local machine IP and available port + ): + conn = pglib.connect( + host=host, + port=port, + dbname=dbname, + user=user, + password=password + ) + return conn diff --git a/testgres/utils.py b/testgres/utils.py index b72c7da0..73c36e2c 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -3,30 +3,18 @@ from __future__ import division from __future__ import print_function -import io import os import port_for -import subprocess import sys -import tempfile from contextlib import contextmanager from packaging.version import Version -try: - from shutil import which as find_executable -except ImportError: - from distutils.spawn import find_executable from six import iteritems -from fabric import Connection -from .operations.remote_ops import RemoteOperations -from .operations.local_ops import LocalOperations -from .operations.os_ops import OsOperations - -from .config import testgres_config -from .exceptions import ExecUtilException +from .config import testgres_config as tconf +from .logger import log # rows returned by PG_CONFIG _pg_config_data = {} @@ -57,90 +45,34 @@ def release_port(port): bound_ports.discard(port) -def execute_utility(args, logfile=None, os_ops: OsOperations = LocalOperations()): +def execute_utility(args, logfile=None): """ Execute utility (pg_ctl, pg_dump etc). Args: - os_ops: LocalOperations for local node or RemoteOperations for node that connected by ssh. args: utility + arguments (list). logfile: path to file to store stdout and stderr. 
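
    Example (illustrative; assumes PG_BIN or pg_config resolves the binary):

        out = execute_utility([get_bin_path('pg_ctl'), '--version'],
                              logfile='/tmp/testgres-utils.log')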
Returns: stdout of executed utility. """ - - if isinstance(os_ops, RemoteOperations): - conn = Connection( - os_ops.hostname, - connect_kwargs={ - "key_filename": f"{os_ops.ssh_key}", - }, - ) - # TODO skip remote ssh run if we are on the localhost. - # result = conn.run('hostname', hide=True) - # add logger - - cmd = ' '.join(args) - result = conn.run(cmd, hide=True) - - return result - - # run utility - if os.name == 'nt': - # using output to a temporary file in Windows - buf = tempfile.NamedTemporaryFile() - - process = subprocess.Popen( - args, # util + params - stdout=buf, - stderr=subprocess.STDOUT) - process.communicate() - - # get result - buf.file.flush() - buf.file.seek(0) - out = buf.file.read() - buf.close() - else: - process = subprocess.Popen( - args, # util + params - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) - - # get result - out, _ = process.communicate() - - # decode result - out = '' if not out else out.decode('utf-8') - - # format command command = u' '.join(args) + exit_status, out, error = tconf.os_ops.exec_command(command, verbose=True) + # decode result + out = '' if not out else out + if isinstance(out, bytes): + out = out.decode('utf-8') # write new log entry if possible if logfile: try: - with io.open(logfile, 'a') as file_out: - file_out.write(command) - - if out: - # comment-out lines - lines = ('# ' + line for line in out.splitlines(True)) - file_out.write(u'\n') - file_out.writelines(lines) - - file_out.write(u'\n') + tconf.os_ops.write(filename=logfile, data=command, truncate=True) + if out: + # comment-out lines + lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] + tconf.os_ops.write(filename=logfile, data=lines) except IOError: - pass - - exit_code = process.returncode - if exit_code: - message = 'Utility exited with non-zero code' - raise ExecUtilException(message=message, - command=command, - exit_code=exit_code, - out=out) - + log.warn(f"Problem with writing to logfile `{logfile}` during run command `{command}`") return out @@ -149,23 +81,22 @@ def get_bin_path(filename): Return absolute path to an executable using PG_BIN or PG_CONFIG. This function does nothing if 'filename' is already absolute. """ - # check if it's already absolute if os.path.isabs(filename): return filename - # try PG_CONFIG + # try PG_CONFIG - get from local machine pg_config = os.environ.get("PG_CONFIG") if pg_config: bindir = get_pg_config()["BINDIR"] return os.path.join(bindir, filename) # try PG_BIN - pg_bin = os.environ.get("PG_BIN") + pg_bin = tconf.os_ops.environ("PG_BIN") if pg_bin: return os.path.join(pg_bin, filename) - pg_config_path = find_executable('pg_config') + pg_config_path = tconf.os_ops.find_executable('pg_config') if pg_config_path: bindir = get_pg_config(pg_config_path)["BINDIR"] return os.path.join(bindir, filename) @@ -181,7 +112,7 @@ def get_pg_config(pg_config_path=None): def cache_pg_config_data(cmd): # execute pg_config and get the output - out = subprocess.check_output([cmd]).decode('utf-8') + out = tconf.os_ops.exec_command(cmd, encoding='utf-8') data = {} for line in out.splitlines(): @@ -196,7 +127,7 @@ def cache_pg_config_data(cmd): return data # drop cache if asked to - if not testgres_config.cache_pg_config: + if not tconf.cache_pg_config: global _pg_config_data _pg_config_data = {} @@ -226,7 +157,7 @@ def get_pg_version(): # get raw version (e.g. 
postgres (PostgreSQL) 9.5.7) _params = [get_bin_path('postgres'), '--version'] - raw_ver = subprocess.check_output(_params).decode('utf-8') + raw_ver = tconf.os_ops.exec_command(_params, encoding='utf-8') # cook version of PostgreSQL version = raw_ver.strip().split(' ')[-1] \ diff --git a/tests/test_remote.py b/tests/test_remote.py index 0155956c..7bc6b2f1 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -1,6 +1,7 @@ import pytest -from testgres.operations.remote_ops import RemoteOperations +from testgres import ExecUtilException +from testgres import RemoteOperations class TestRemoteOperations: @@ -31,9 +32,11 @@ def test_exec_command_failure(self): Test exec_command for command execution failure. """ cmd = "nonexistent_command" - exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) - - assert error == b'bash: line 1: nonexistent_command: command not found\n' + try: + exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) + except ExecUtilException as e: + error = e.message + assert error == 'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n' def test_is_executable_true(self): """ @@ -82,8 +85,11 @@ def test_makedirs_and_rmdirs_failure(self): self.operations.makedirs(path) # Test rmdirs - exit_status, result, error = self.operations.rmdirs(path, verbose=True) - assert error == b"rm: cannot remove '/root/test_dir': Permission denied\n" + try: + exit_status, result, error = self.operations.rmdirs(path, verbose=True) + except ExecUtilException as e: + error = e.message + assert error == "Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n" def test_listdir(self): """ @@ -119,11 +125,12 @@ def test_write_text_file(self): filename = "/tmp/test_file.txt" data = "Hello, world!" 
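+        # Note: write() appends unless truncate=True. A sketch of the intended
+        # semantics (placeholder path):
+        #     ops.write('/tmp/d.txt', 'a', truncate=True)  # file now contains 'a'
+        #     ops.write('/tmp/d.txt', 'b')                  # appended: 'ab'
+        # That is why the doubled write below expects data + data.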
+ self.operations.write(filename, data, truncate=True) self.operations.write(filename, data) response = self.operations.read(filename) - assert response == data + assert response == data + data def test_write_binary_file(self): """ @@ -132,7 +139,7 @@ def test_write_binary_file(self): filename = "/tmp/test_file.bin" data = b"\x00\x01\x02\x03" - self.operations.write(filename, data, binary=True) + self.operations.write(filename, data, binary=True, truncate=True) response = self.operations.read(filename, binary=True) diff --git a/tests/test_simple.py b/tests/test_simple.py index e8b8abee..2f8ff62b 100755 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -252,34 +252,34 @@ def test_psql(self): # check returned values (1 arg) res = node.psql('select 1') - self.assertEqual(res, (0, b'1\n', b'')) + self.assertEqual((0, b'1\n', b''), res) # check returned values (2 args) res = node.psql('postgres', 'select 2') - self.assertEqual(res, (0, b'2\n', b'')) + self.assertEqual((0, b'2\n', b''), res) # check returned values (named) res = node.psql(query='select 3', dbname='postgres') - self.assertEqual(res, (0, b'3\n', b'')) + self.assertEqual((0, b'3\n', b''), res) # check returned values (1 arg) res = node.safe_psql('select 4') - self.assertEqual(res, b'4\n') + self.assertEqual(b'4\n', res) # check returned values (2 args) res = node.safe_psql('postgres', 'select 5') - self.assertEqual(res, b'5\n') + self.assertEqual(b'5\n', res) # check returned values (named) res = node.safe_psql(query='select 6', dbname='postgres') - self.assertEqual(res, b'6\n') + self.assertEqual(b'6\n', res) # check feeding input node.safe_psql('create table horns (w int)') node.safe_psql('copy horns from stdin (format csv)', input=b"1\n2\n3\n\\.\n") _sum = node.safe_psql('select sum(w) from horns') - self.assertEqual(b'6\n', _sum) + self.assertEqual(_sum, b'6\n') # check psql's default args, fails with self.assertRaises(QueryException): diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 179f3ffb..18c3450a 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -16,7 +16,6 @@ import logging.config from contextlib import contextmanager -from shutil import rmtree from testgres.exceptions import \ InitNodeException, \ @@ -31,13 +30,13 @@ TestgresConfig, \ configure_testgres, \ scoped_config, \ - pop_config + pop_config, testgres_config from testgres import \ NodeStatus, \ ProcessType, \ IsolationLevel, \ - get_new_node + get_new_node, RemoteOperations from testgres import \ get_bin_path, \ @@ -54,6 +53,12 @@ from testgres.node import ProcessProxy +os_ops = RemoteOperations(host='172.18.0.3', + username='dev', + ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519') +testgres_config.set_os_ops(os_ops=os_ops) + + def pg_version_ge(version): cur_ver = PgVer(get_pg_version()) min_ver = PgVer(version) @@ -62,16 +67,16 @@ def pg_version_ge(version): def util_exists(util): def good_properties(f): - return (os.path.exists(f) and # noqa: W504 - os.path.isfile(f) and # noqa: W504 - os.access(f, os.X_OK)) # yapf: disable + return (os_ops.path_exists(f) and # noqa: W504 + os_ops.isfile(f) and # noqa: W504 + os_ops.is_executable(f)) # yapf: disable # try to resolve it if good_properties(get_bin_path(util)): return True # check if util is in PATH - for path in os.environ["PATH"].split(os.pathsep): + for path in os_ops.environ("PATH").split(os_ops.pathsep): if good_properties(os.path.join(path, util)): return True @@ -81,14 +86,15 @@ def 
removing(f): try: yield f finally: - if os.path.isfile(f): - os.remove(f) - elif os.path.isdir(f): - rmtree(f, ignore_errors=True) + if os_ops.isfile(f): + os_ops.remove_file(f) + + elif os_ops.isdir(f): + os_ops.rmdirs(f, ignore_errors=True) def get_remote_node(): - return get_new_node(host='172.18.0.3', username='dev', ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519') + return get_new_node(host=os_ops.host, username=os_ops.username, ssh_key=os_ops.ssh_key) class TestgresRemoteTests(unittest.TestCase): @@ -109,14 +115,13 @@ def test_custom_init(self): initdb_params=['--auth-local=reject', '--auth-host=reject']) hba_file = os.path.join(node.data_dir, 'pg_hba.conf') - with open(hba_file, 'r') as conf: - lines = conf.readlines() + lines = os_ops.readlines(hba_file) - # check number of lines - self.assertGreaterEqual(len(lines), 6) + # check number of lines + self.assertGreaterEqual(len(lines), 6) - # there should be no trust entries at all - self.assertFalse(any('trust' in s for s in lines)) + # there should be no trust entries at all + self.assertFalse(any('trust' in s for s in lines)) def test_double_init(self): with get_remote_node().init() as node: @@ -164,14 +169,14 @@ def test_node_exit(self): node.safe_psql('select 1') # we should save the DB for "debugging" - self.assertTrue(os.path.exists(base_dir)) - rmtree(base_dir, ignore_errors=True) + self.assertTrue(os_ops.path_exists(base_dir)) + os_ops.rmdirs(base_dir, ignore_errors=True) with get_remote_node().init() as node: base_dir = node.base_dir # should have been removed by default - self.assertFalse(os.path.exists(base_dir)) + self.assertFalse(os_ops.path_exists(base_dir)) def test_double_start(self): with get_remote_node().init().start() as node: @@ -607,9 +612,9 @@ def test_dump(self): with removing(node1.dump(format=format)) as dump: with get_remote_node().init().start() as node3: if format == 'directory': - self.assertTrue(os.path.isdir(dump)) + self.assertTrue(os_ops.isdir(dump)) else: - self.assertTrue(os.path.isfile(dump)) + self.assertTrue(os_ops.isfile(dump)) # restore dump node3.restore(filename=dump) res = node3.execute(query_select) @@ -986,14 +991,14 @@ def test_child_process_dies(self): if __name__ == '__main__': - if os.environ.get('ALT_CONFIG'): + if os_ops.environ('ALT_CONFIG'): suite = unittest.TestSuite() # Small subset of tests for alternative configs (PG_BIN or PG_CONFIG) - suite.addTest(TestgresTests('test_pg_config')) - suite.addTest(TestgresTests('test_pg_ctl')) - suite.addTest(TestgresTests('test_psql')) - suite.addTest(TestgresTests('test_replicate')) + suite.addTest(TestgresRemoteTests('test_pg_config')) + suite.addTest(TestgresRemoteTests('test_pg_ctl')) + suite.addTest(TestgresRemoteTests('test_psql')) + suite.addTest(TestgresRemoteTests('test_replicate')) print('Running tests for alternative config:') for t in suite: From 2c2d2c5cf0eaf6a5a4c21fbf2d589adbdccbdbed Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Thu, 22 Jun 2023 23:14:55 +0200 Subject: [PATCH 14/23] PBCKP-588 test fix test_restore_after_failover --- testgres/node.py | 13 +++-- testgres/operations/local_ops.py | 19 ++----- testgres/operations/os_ops.py | 3 -- testgres/operations/remote_ops.py | 84 +++++++++++++++++++------------ testgres/utils.py | 16 +++--- 5 files changed, 74 insertions(+), 61 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 5ad18ace..2ab17c75 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -625,9 +625,11 @@ def status(self): "-D", 
self.data_dir,
             "status"
         ] # yapf: disable
-            out = execute_utility(_params, self.utils_log_file)
-            if 'no server running' in out:
+            status_code, out, err = execute_utility(_params, self.utils_log_file, verbose=True)
+            if 'does not exist' in err:
                 return NodeStatus.Uninitialized
+            elif 'no server running' in out:
+                return NodeStatus.Stopped
             return NodeStatus.Running
         except ExecUtilException as e:
@@ -712,14 +714,17 @@ def start(self, params=[], wait=True):
         ] + params # yapf: disable

         try:
-            execute_utility(_params, self.utils_log_file)
+            exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True)
+            if 'does not exist' in error:
+                raise Exception
+            if 'server started' in out:
+                self.is_started = True
         except Exception as e:
             msg = 'Cannot start node'
             files = self._collect_special_files()
             raise_from(StartNodeException(msg, files), e)

         self._maybe_start_logger()
-        self.is_started = True
         return self

     def stop(self, params=[], wait=True):
diff --git a/testgres/operations/local_ops.py
index 010e3cc0..6a26910d 100644
--- a/testgres/operations/local_ops.py
+++ b/testgres/operations/local_ops.py
@@ -52,7 +52,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
         :return: The output of the subprocess.
         """
         if isinstance(cmd, list):
-            cmd = " ".join(cmd)
+            cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd)
         log.debug(f"Executing command: `{cmd}`")

         if os.name == 'nt':
@@ -98,8 +98,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,

     # Environment setup
     def environ(self, var_name):
-        cmd = f"echo ${var_name}"
-        return self.exec_command(cmd, encoding='utf-8').strip()
+        return os.environ.get(var_name)

     def find_executable(self, executable):
         return find_executable(executable)
@@ -108,17 +107,6 @@ def is_executable(self, file):
         # Check if the file is executable
         return os.access(file, os.X_OK)

-    def add_to_path(self, new_path):
-        pathsep = self.pathsep
-        # Check if the directory is already in PATH
-        path = self.environ("PATH")
-        if new_path not in path.split(pathsep):
-            if self.remote:
-                self.exec_command(f"export PATH={new_path}{pathsep}{path}")
-            else:
-                os.environ["PATH"] = f"{new_path}{pathsep}{path}"
-        return pathsep
-
     def set_env(self, var_name, var_val):
         # Check if the directory is already in PATH
         os.environ[var_name] = var_val
@@ -128,8 +116,7 @@ def get_user(self):
         return getpass.getuser()

     def get_name(self):
-        cmd = 'python3 -c "import os; print(os.name)"'
-        return self.exec_command(cmd).strip()
+        return os.name

     # Work with dirs
     def makedirs(self, path, remove_existing=False):
diff --git a/testgres/operations/os_ops.py
index 68925616..c3f57653 100644
--- a/testgres/operations/os_ops.py
+++ b/testgres/operations/os_ops.py
@@ -29,9 +29,6 @@ def is_executable(self, file):
         # Check if the file is executable
         raise NotImplementedError()

-    def add_to_path(self, new_path):
-        raise NotImplementedError()
-
     def set_env(self, var_name, var_val):
         # Check if the directory is already in PATH
         raise NotImplementedError()
diff --git a/testgres/operations/remote_ops.py
index d45614a1..8e94a7fe 100644
--- a/testgres/operations/remote_ops.py
+++ b/testgres/operations/remote_ops.py
@@ -20,6 +20,22 @@
 error_markers = [b'error', b'Permission denied']


+class PsUtilProcessProxy:
+    def __init__(self, ssh, pid):
+        self.ssh = ssh
+        self.pid = pid
+
+    def kill(self):
+        command = f"kill {self.pid}"
+        self.ssh.exec_command(command)
+
+    def cmdline(self):
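+        # Expected shape is a list of argv items, e.g. (hypothetical output):
+        #     ['/usr/lib/postgresql/bin/postgres', '-D', '/tmp/data']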
+ command = f"ps -p {self.pid} -o cmd --no-headers" + stdin, stdout, stderr = self.ssh.exec_command(command) + cmdline = stdout.read().decode('utf-8').strip() + return cmdline.split() + + class RemoteOperations(OsOperations): def __init__(self, host="127.0.0.1", hostname='localhost', port=None, ssh_key=None, username=None): super().__init__(username) @@ -71,7 +87,7 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa self.ssh = self.ssh_connect() if isinstance(cmd, list): - cmd = " ".join(cmd) + cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) if input: stdin, stdout, stderr = self.ssh.exec_command(cmd) stdin.write(input) @@ -140,17 +156,6 @@ def is_executable(self, file): is_exec = self.exec_command(f"test -x {file} && echo OK") return is_exec == b"OK\n" - def add_to_path(self, new_path): - pathsep = self.pathsep - # Check if the directory is already in PATH - path = self.environ("PATH") - if new_path not in path.split(pathsep): - if self.remote: - self.exec_command(f"export PATH={new_path}{pathsep}{path}") - else: - os.environ["PATH"] = f"{new_path}{pathsep}{path}" - return pathsep - def set_env(self, var_name: str, var_val: str): """ Set the value of an environment variable. @@ -243,9 +248,17 @@ def mkdtemp(self, prefix=None): raise ExecUtilException("Could not create temporary directory.") def mkstemp(self, prefix=None): - cmd = f"mktemp {prefix}XXXXXX" - filename = self.exec_command(cmd).strip() - return filename + if prefix: + temp_dir = self.exec_command(f"mktemp {prefix}XXXXX", encoding='utf-8') + else: + temp_dir = self.exec_command("mktemp", encoding='utf-8') + + if temp_dir: + if not os.path.isabs(temp_dir): + temp_dir = os.path.join('/home', self.username, temp_dir.strip()) + return temp_dir + else: + raise ExecUtilException("Could not create temporary directory.") def copytree(self, src, dst): if not os.path.isabs(dst): @@ -291,7 +304,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal data = data.encode(encoding) if isinstance(data, list): # ensure each line ends with a newline - data = [s if s.endswith('\n') else s + '\n' for s in data] + data = [(s if isinstance(s, str) else s.decode('utf-8')).rstrip('\n') + '\n' for s in data] tmp_file.writelines(data) else: tmp_file.write(data) @@ -351,8 +364,8 @@ def isfile(self, remote_file): def isdir(self, dirname): cmd = f"if [ -d {dirname} ]; then echo True; else echo False; fi" - response = self.exec_command(cmd, encoding='utf-8') - return response.strip() == "True" + response = self.exec_command(cmd) + return response.strip() == b"True" def remove_file(self, filename): cmd = f"rm {filename}" @@ -366,16 +379,16 @@ def kill(self, pid, signal): def get_pid(self): # Get current process id - return self.exec_command("echo $$") + return int(self.exec_command("echo $$", encoding='utf-8')) def get_remote_children(self, pid): command = f"pgrep -P {pid}" stdin, stdout, stderr = self.ssh.exec_command(command) children = stdout.readlines() - return [int(child_pid.strip()) for child_pid in children] + return [PsUtilProcessProxy(self.ssh, int(child_pid.strip())) for child_pid in children] # Database control - def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432): + def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, ssh_key=None): """ Connects to a PostgreSQL database on the remote system. 
Args: @@ -389,19 +402,26 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432): This function establishes a connection to a PostgreSQL database on the remote system using the specified parameters. It returns a connection object that can be used to interact with the database. """ - with sshtunnel.open_tunnel( - (host, 22), # Remote server IP and SSH port - ssh_username=self.username, - ssh_pkey=self.ssh_key, - remote_bind_address=(host, port), # PostgreSQL server IP and PostgreSQL port - local_bind_address=('localhost', port), # Local machine IP and available port - ): + tunnel = sshtunnel.open_tunnel( + (host, 22), # Remote server IP and SSH port + ssh_username=user or self.username, + ssh_pkey=ssh_key or self.ssh_key, + remote_bind_address=(host, port), # PostgreSQL server IP and PostgreSQL port + local_bind_address=('localhost', port) # Local machine IP and available port + ) + + tunnel.start() + + try: conn = pglib.connect( - host=host, - port=port, + host=host, # change to 'localhost' because we're connecting through a local ssh tunnel + port=tunnel.local_bind_port, # use the local bind port set up by the tunnel dbname=dbname, - user=user, + user=user or self.username, password=password ) - return conn + return conn + except Exception as e: + tunnel.stop() + raise e diff --git a/testgres/utils.py b/testgres/utils.py index 73c36e2c..d8321b3e 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -45,7 +45,7 @@ def release_port(port): bound_ports.discard(port) -def execute_utility(args, logfile=None): +def execute_utility(args, logfile=None, verbose=False): """ Execute utility (pg_ctl, pg_dump etc). @@ -56,24 +56,28 @@ def execute_utility(args, logfile=None): Returns: stdout of executed utility. """ - command = u' '.join(args) - exit_status, out, error = tconf.os_ops.exec_command(command, verbose=True) + exit_status, out, error = tconf.os_ops.exec_command(args, verbose=True) # decode result out = '' if not out else out if isinstance(out, bytes): out = out.decode('utf-8') + if isinstance(error, bytes): + error = error.decode('utf-8') # write new log entry if possible if logfile: try: - tconf.os_ops.write(filename=logfile, data=command, truncate=True) + tconf.os_ops.write(filename=logfile, data=args, truncate=True) if out: # comment-out lines lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] tconf.os_ops.write(filename=logfile, data=lines) except IOError: - log.warn(f"Problem with writing to logfile `{logfile}` during run command `{command}`") - return out + log.warn(f"Problem with writing to logfile `{logfile}` during run command `{args}`") + if verbose: + return exit_status, out, error + else: + return out def get_bin_path(filename): From 1b4f74aa1f9eb48a7c19a07eba7a6a2083b5c26a Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Fri, 23 Jun 2023 01:55:22 +0200 Subject: [PATCH 15/23] PBCKP-588 test partially fixed test_simple_remote.py 41/43 --- testgres/node.py | 8 +++++--- testgres/operations/remote_ops.py | 30 +++++++++++++++++++++++++----- testgres/pubsub.py | 2 +- tests/test_simple_remote.py | 25 +++++++++++++------------ 4 files changed, 44 insertions(+), 21 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 2ab17c75..7a6e475c 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -519,8 +519,8 @@ def get_auth_method(t): u"local\treplication\tall\t\t\t{}\n".format(auth_local), u"host\treplication\tall\t127.0.0.1/32\t{}\n".format(auth_host), u"host\treplication\tall\t::1/128\t\t{}\n".format(auth_host), - 
u"host\treplication\t{}\t{}/24\t\t{}\n".format(self.os_ops.username, subnet_base, auth_host), - u"host\tall\t{}\t{}/24\t\t{}\n".format(self.os_ops.username, subnet_base, auth_host) + u"host\treplication\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host), + u"host\tall\tall\t{}/24\t\t{}\n".format(subnet_base, auth_host) ] # yapf: disable # write missing lines @@ -790,7 +790,9 @@ def restart(self, params=[]): ] + params # yapf: disable try: - execute_utility(_params, self.utils_log_file) + error_code, out, error = execute_utility(_params, self.utils_log_file, verbose=True) + if 'could not start server' in error: + raise ExecUtilException except ExecUtilException as e: msg = 'Cannot restart node' files = self._collect_special_files() diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 8e94a7fe..274a87cf 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -1,5 +1,6 @@ import os import tempfile +import time from typing import Optional import sshtunnel @@ -46,11 +47,29 @@ def __init__(self, host="127.0.0.1", hostname='localhost', port=None, ssh_key=No self.remote = True self.ssh = self.ssh_connect() self.username = username or self.get_user() + self.tunnel = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close_tunnel() + if getattr(self, 'ssh', None): + self.ssh.close() def __del__(self): - if self.ssh: + if getattr(self, 'ssh', None): self.ssh.close() + def close_tunnel(self): + if getattr(self, 'tunnel', None): + self.tunnel.stop(force=True) + start_time = time.time() + while self.tunnel.is_active: + if time.time() - start_time > sshtunnel.TUNNEL_TIMEOUT: + break + time.sleep(0.5) + def ssh_connect(self) -> Optional[SSHClient]: if not self.remote: return None @@ -402,7 +421,8 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s This function establishes a connection to a PostgreSQL database on the remote system using the specified parameters. It returns a connection object that can be used to interact with the database. 
""" - tunnel = sshtunnel.open_tunnel( + self.close_tunnel() + self.tunnel = sshtunnel.open_tunnel( (host, 22), # Remote server IP and SSH port ssh_username=user or self.username, ssh_pkey=ssh_key or self.ssh_key, @@ -410,12 +430,12 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s local_bind_address=('localhost', port) # Local machine IP and available port ) - tunnel.start() + self.tunnel.start() try: conn = pglib.connect( host=host, # change to 'localhost' because we're connecting through a local ssh tunnel - port=tunnel.local_bind_port, # use the local bind port set up by the tunnel + port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel dbname=dbname, user=user or self.username, password=password @@ -423,5 +443,5 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s return conn except Exception as e: - tunnel.stop() + self.tunnel.stop() raise e diff --git a/testgres/pubsub.py b/testgres/pubsub.py index da85caac..1be673bb 100644 --- a/testgres/pubsub.py +++ b/testgres/pubsub.py @@ -214,4 +214,4 @@ def catchup(self, username=None): username=username or self.pub.username, max_attempts=LOGICAL_REPL_MAX_CATCHUP_ATTEMPTS) except Exception as e: - raise_from(CatchUpException("Failed to catch up", query), e) + raise_from(CatchUpException("Failed to catch up"), e) diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 18c3450a..0b104ff0 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -6,7 +6,6 @@ import subprocess import tempfile - import testgres import time import six @@ -138,6 +137,7 @@ def test_init_after_cleanup(self): @unittest.skipUnless(util_exists('pg_resetwal'), 'might be missing') @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') def test_init_unique_system_id(self): + # FAIL # this function exists in PostgreSQL 9.6+ query = 'select system_identifier from pg_control_system()' @@ -291,7 +291,7 @@ def test_psql(self): node.safe_psql('copy horns from stdin (format csv)', input=b"1\n2\n3\n\\.\n") _sum = node.safe_psql('select sum(w) from horns') - self.assertEqual(b'6\n', _sum) + self.assertEqual(_sum, b'6\n') # check psql's default args, fails with self.assertRaises(QueryException): @@ -688,6 +688,7 @@ def test_poll_query_until(self): node.poll_query_until('select true') def test_logging(self): + # FAIL logfile = tempfile.NamedTemporaryFile('w', delete=True) log_conf = { @@ -747,14 +748,11 @@ def test_pgbench(self): options=['-q']).pgbench_run(time=2) # run TPC-B benchmark - proc = node.pgbench(stdout=subprocess.PIPE, + out = node.pgbench(stdout=subprocess.PIPE, stderr=subprocess.STDOUT, options=['-T3']) - out, _ = proc.communicate() - out = out.decode('utf-8') - - self.assertTrue('tps' in out) + self.assertTrue(b'tps = ' in out) def test_pg_config(self): # check same instances @@ -764,7 +762,6 @@ def test_pg_config(self): # save right before config change c1 = get_pg_config() - # modify setting for this scope with scoped_config(cache_pg_config=False) as config: @@ -819,12 +816,16 @@ def test_unix_sockets(self): node.init(unix_sockets=False, allow_streaming=True) node.start() - node.execute('select 1') - node.safe_psql('select 1') + res_exec = node.execute('select 1') + res_psql = node.safe_psql('select 1') + self.assertEqual(res_exec, [(1,)]) + self.assertEqual(res_psql, b'1\n') with node.replicate().start() as r: - r.execute('select 1') - r.safe_psql('select 1') + res_exec = r.execute('select 1') + res_psql = r.safe_psql('select 1') + 
self.assertEqual(res_exec, [(1,)]) + self.assertEqual(res_psql, b'1\n') def test_auto_name(self): with get_remote_node().init(allow_streaming=True).start() as m: From 2e916dfb4d044c24c421e80e4cddaa9799d0114d Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 26 Jun 2023 01:30:56 +0200 Subject: [PATCH 16/23] PBCKP-588 fixes after review --- testgres/logger.py | 15 --------------- testgres/node.py | 2 +- testgres/operations/local_ops.py | 4 +--- testgres/operations/os_ops.py | 3 +++ testgres/operations/remote_ops.py | 12 +++++------- testgres/utils.py | 5 ++--- tests/test_remote.py | 2 +- tests/test_simple.py | 16 +++++++++------- tests/test_simple_remote.py | 15 ++++++++------- 9 files changed, 30 insertions(+), 44 deletions(-) diff --git a/testgres/logger.py b/testgres/logger.py index 59579002..b4648f44 100644 --- a/testgres/logger.py +++ b/testgres/logger.py @@ -5,21 +5,6 @@ import threading import time -# create logger -log = logging.getLogger('Testgres') - -if not log.handlers: - log.setLevel(logging.WARN) - # create console handler and set level to debug - ch = logging.StreamHandler() - ch.setLevel(logging.WARN) - # create formatter - formatter = logging.Formatter('\n%(asctime)s - %(name)s[%(levelname)s]: %(message)s') - # add formatter to ch - ch.setFormatter(formatter) - # add ch to logger - log.addHandler(ch) - class TestgresLogger(threading.Thread): """ diff --git a/testgres/node.py b/testgres/node.py index 7a6e475c..5d3dfb72 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -245,7 +245,7 @@ def child_processes(self): """ # get a list of postmaster's children - children = self.os_ops.get_remote_children(self.pid) + children = self.os_ops.get_process_children(self.pid) return [ProcessProxy(p) for p in children] diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index 6a26910d..acd066cf 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -8,7 +8,6 @@ import psutil from testgres.exceptions import ExecUtilException -from testgres.logger import log from .os_ops import OsOperations from .os_ops import pglib @@ -53,7 +52,6 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, """ if isinstance(cmd, list): cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) - log.debug(f"Executing command: `{cmd}`") if os.name == 'nt': with tempfile.NamedTemporaryFile() as buf: @@ -252,7 +250,7 @@ def get_pid(self): # Get current process id return os.getpid() - def get_remote_children(self, pid): + def get_process_children(self, pid): return psutil.Process(pid).children() # Database control diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index c3f57653..4b1349b7 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -88,6 +88,9 @@ def get_pid(self): # Get current process id raise NotImplementedError() + def get_process_children(self, pid): + raise NotImplementedError() + # Database control def db_connect(self, dbname, user, password=None, host="localhost", port=5432): raise NotImplementedError() diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 274a87cf..b27ca47d 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -9,7 +9,6 @@ from paramiko import SSHClient from testgres.exceptions import ExecUtilException -from testgres.logger import log from .os_ops import OsOperations from .os_ops import pglib @@ -90,9 +89,9 @@ def _read_ssh_key(self): key = 
paramiko.RSAKey.from_private_key_file(self.ssh_key) return key except FileNotFoundError: - log.error(f"No such file or directory: '{self.ssh_key}'") + raise ExecUtilException(message=f"No such file or directory: '{self.ssh_key}'") except Exception as e: - log.error(f"An error occurred while reading the ssh key: {e}") + ExecUtilException(message=f"An error occurred while reading the ssh key: {e}") def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdout=None, @@ -400,7 +399,7 @@ def get_pid(self): # Get current process id return int(self.exec_command("echo $$", encoding='utf-8')) - def get_remote_children(self, pid): + def get_process_children(self, pid): command = f"pgrep -P {pid}" stdin, stdout, stderr = self.ssh.exec_command(command) children = stdout.readlines() @@ -414,8 +413,7 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s - dbname (str): The name of the database to connect to. - user (str): The username for the database connection. - password (str, optional): The password for the database connection. Defaults to None. - - host (str, optional): The IP address of the remote system. Defaults to "127.0.0.1". - - hostname (str, optional): The hostname of the remote system. Defaults to "localhost". + - host (str, optional): The IP address of the remote system. Defaults to "localhost". - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432. This function establishes a connection to a PostgreSQL database on the remote system using the specified @@ -444,4 +442,4 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s return conn except Exception as e: self.tunnel.stop() - raise e + raise ExecUtilException("Could not create db tunnel.") diff --git a/testgres/utils.py b/testgres/utils.py index d8321b3e..273c9287 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -12,9 +12,8 @@ from six import iteritems - +from .exceptions import ExecUtilException from .config import testgres_config as tconf -from .logger import log # rows returned by PG_CONFIG _pg_config_data = {} @@ -73,7 +72,7 @@ def execute_utility(args, logfile=None, verbose=False): lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] tconf.os_ops.write(filename=logfile, data=lines) except IOError: - log.warn(f"Problem with writing to logfile `{logfile}` during run command `{args}`") + raise ExecUtilException(f"Problem with writing to logfile `{logfile}` during run command `{args}`") if verbose: return exit_status, out, error else: diff --git a/tests/test_remote.py b/tests/test_remote.py index 7bc6b2f1..cdaa6574 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -11,7 +11,7 @@ def setup(self): self.operations = RemoteOperations( host="172.18.0.3", username="dev", - ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519' + ssh_key='../../container_files/postgres/ssh/id_ed25519' ) yield diff --git a/tests/test_simple.py b/tests/test_simple.py index 2f8ff62b..94420b04 100755 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -151,6 +151,8 @@ def test_init_unique_system_id(self): self.assertGreater(id2, id1) def test_node_exit(self): + base_dir = None + with self.assertRaises(QueryException): with get_new_node().init() as node: base_dir = node.base_dir @@ -252,27 +254,27 @@ def test_psql(self): # check returned values (1 arg) res = node.psql('select 1') - self.assertEqual((0, 
b'1\n', b''), res) + self.assertEqual(res, (0, b'1\n', b'')) # check returned values (2 args) res = node.psql('postgres', 'select 2') - self.assertEqual((0, b'2\n', b''), res) + self.assertEqual(res, (0, b'2\n', b'')) # check returned values (named) res = node.psql(query='select 3', dbname='postgres') - self.assertEqual((0, b'3\n', b''), res) + self.assertEqual(res, (0, b'3\n', b'')) # check returned values (1 arg) res = node.safe_psql('select 4') - self.assertEqual(b'4\n', res) + self.assertEqual(res, b'4\n') # check returned values (2 args) res = node.safe_psql('postgres', 'select 5') - self.assertEqual(b'5\n', res) + self.assertEqual(res, b'5\n') # check returned values (named) res = node.safe_psql(query='select 6', dbname='postgres') - self.assertEqual(b'6\n', res) + self.assertEqual(res, b'6\n') # check feeding input node.safe_psql('create table horns (w int)') @@ -612,7 +614,7 @@ def test_users(self): with get_new_node().init().start() as node: node.psql('create role test_user login') value = node.safe_psql('select 1', username='test_user') - self.assertEqual(b'1\n', value) + self.assertEqual(value, b'1\n') def test_poll_query_until(self): with get_new_node() as node: diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 0b104ff0..f86e623f 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -35,7 +35,8 @@ NodeStatus, \ ProcessType, \ IsolationLevel, \ - get_new_node, RemoteOperations + get_new_node, \ + RemoteOperations from testgres import \ get_bin_path, \ @@ -54,7 +55,7 @@ os_ops = RemoteOperations(host='172.18.0.3', username='dev', - ssh_key='/home/vika/Desktop/work/probackup/dev-ee-probackup/container_files/postgres/ssh/id_ed25519') + ssh_key='../../container_files/postgres/ssh/id_ed25519') testgres_config.set_os_ops(os_ops=os_ops) @@ -92,8 +93,8 @@ def removing(f): os_ops.rmdirs(f, ignore_errors=True) -def get_remote_node(): - return get_new_node(host=os_ops.host, username=os_ops.username, ssh_key=os_ops.ssh_key) +def get_remote_node(name=None): + return get_new_node(name=name, host=os_ops.host, username=os_ops.username, ssh_key=os_ops.ssh_key) class TestgresRemoteTests(unittest.TestCase): @@ -696,7 +697,7 @@ def test_logging(self): 'handlers': { 'file': { 'class': 'logging.FileHandler', - 'filename': logfile.name, + 'filename': logfile, 'formatter': 'base_format', 'level': logging.DEBUG, }, @@ -717,7 +718,7 @@ def test_logging(self): with scoped_config(use_python_logging=True): node_name = 'master' - with get_new_node(name=node_name) as master: + with get_remote_node(name=node_name) as master: master.init().start() # execute a dummy query a few times @@ -729,7 +730,7 @@ def test_logging(self): time.sleep(0.1) # check that master's port is found - with open(logfile.name, 'r') as log: + with open(logfile, 'r') as log: lines = log.readlines() self.assertTrue(any(node_name in s for s in lines)) From 0528541e70a3a3dbe3e019b7bc02552c84e40649 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 26 Jun 2023 02:23:41 +0200 Subject: [PATCH 17/23] PBCKP-588 fixes after review - add ConnectionParams --- testgres/__init__.py | 4 ++-- testgres/api.py | 9 ++++----- testgres/backup.py | 2 +- testgres/node.py | 17 +++++++---------- testgres/operations/local_ops.py | 15 +++++++-------- testgres/operations/os_ops.py | 11 ++++++++--- testgres/operations/remote_ops.py | 31 +++++++++++++------------------ tests/test_remote.py | 11 +++++------ tests/test_simple_remote.py | 12 ++++++------ 9 files changed, 53 insertions(+), 59 deletions(-) diff --git 
a/testgres/__init__.py b/testgres/__init__.py index 405262dd..ce2636b4 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -46,7 +46,7 @@ First, \ Any -from .operations.os_ops import OsOperations +from .operations.os_ops import OsOperations, ConnectionParams from .operations.local_ops import LocalOperations from .operations.remote_ops import RemoteOperations @@ -60,5 +60,5 @@ "PostgresNode", "NodeApp", "reserve_port", "release_port", "bound_ports", "get_bin_path", "get_pg_config", "get_pg_version", "First", "Any", - "OsOperations", "LocalOperations", "RemoteOperations" + "OsOperations", "LocalOperations", "RemoteOperations", "ConnectionParams" ] diff --git a/testgres/api.py b/testgres/api.py index b5b76715..8f553529 100644 --- a/testgres/api.py +++ b/testgres/api.py @@ -37,11 +37,10 @@ def get_new_node(name=None, base_dir=None, **kwargs): """ Simply a wrapper around :class:`.PostgresNode` constructor. See :meth:`.PostgresNode.__init__` for details. - For remote connection you can add next parameters: - host='127.0.0.1', - hostname='localhost', - ssh_key=None, - username=default_username() + For remote connection you can add the next parameter: + conn_params = ConnectionParams(host='127.0.0.1', + ssh_key=None, + username=default_username()) """ # NOTE: leave explicit 'name' and 'base_dir' for compatibility return PostgresNode(name=name, base_dir=base_dir, **kwargs) diff --git a/testgres/backup.py b/testgres/backup.py index c4cc952b..a89e214d 100644 --- a/testgres/backup.py +++ b/testgres/backup.py @@ -139,7 +139,7 @@ def spawn_primary(self, name=None, destroy=True): # Build a new PostgresNode NodeClass = self.original_node.__class__ - with clean_on_error(NodeClass(name=name, base_dir=base_dir, os_ops=self.original_node.os_ops)) as node: + with clean_on_error(NodeClass(name=name, base_dir=base_dir, conn_params=self.original_node.os_ops.conn_params)) as node: # New nodes should always remove dir tree node._should_rm_dirs = True diff --git a/testgres/node.py b/testgres/node.py index 5d3dfb72..d12e7324 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -94,6 +94,7 @@ from .backup import NodeBackup +from .operations.os_ops import ConnectionParams from .operations.local_ops import LocalOperations from .operations.remote_ops import RemoteOperations @@ -125,8 +126,7 @@ def __repr__(self): class PostgresNode(object): - def __init__(self, name=None, port=None, base_dir=None, - host='127.0.0.1', hostname='localhost', ssh_key=None, username=default_username(), os_ops=None): + def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionParams = ConnectionParams()): """ PostgresNode constructor. 
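The hunk above collapses PostgresNode's separate `host`, `hostname`, `ssh_key` and `username` keyword arguments into a single `conn_params` object. A minimal sketch of what calling code looks like after this change, using only names this patch series exports (`ConnectionParams`, `get_new_node`); the host, user and key values are placeholders:

```python
from testgres import ConnectionParams, get_new_node

# The default ConnectionParams has no ssh_key, so the node uses LocalOperations.
with get_new_node(name='local_node') as node:
    node.init().start()
    print(node.execute('select 1'))  # [(1,)]

# A non-empty ssh_key switches the node to RemoteOperations over SSH.
params = ConnectionParams(host='10.0.0.5',                # placeholder host
                          username='dev',                 # placeholder user
                          ssh_key='/path/to/id_ed25519')  # placeholder key
with get_new_node(name='remote_node', conn_params=params) as node:
    node.init().start()
    print(node.execute('select 1'))
```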
@@ -146,17 +146,14 @@ def __init__(self, name=None, port=None, base_dir=None, # basic self.name = name or generate_app_name() - if os_ops: - self.os_ops = os_ops - elif ssh_key: - self.os_ops = RemoteOperations(host=host, hostname=hostname, ssh_key=ssh_key, username=username) + if conn_params.ssh_key: + self.os_ops = RemoteOperations(conn_params) else: - self.os_ops = LocalOperations(host=host, hostname=hostname, username=username) + self.os_ops = LocalOperations(conn_params) - self.port = self.os_ops.port or reserve_port() + self.port = port or reserve_port() self.host = self.os_ops.host - self.hostname = self.os_ops.hostname self.ssh_key = self.os_ops.ssh_key testgres_config.os_ops = self.os_ops @@ -628,7 +625,7 @@ def status(self): status_code, out, err = execute_utility(_params, self.utils_log_file, verbose=True) if 'does not exist' in err: return NodeStatus.Uninitialized - elif'no server running' in out: + elif 'no server running' in out: return NodeStatus.Stopped return NodeStatus.Running diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index acd066cf..bbe6b0d4 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -7,9 +7,9 @@ import psutil -from testgres.exceptions import ExecUtilException +from ..exceptions import ExecUtilException -from .os_ops import OsOperations +from .os_ops import OsOperations, ConnectionParams from .os_ops import pglib try: @@ -21,13 +21,12 @@ class LocalOperations(OsOperations): - def __init__(self, host='127.0.0.1', hostname='localhost', port=None, username=None): - super().__init__(username) - self.host = host - self.hostname = hostname - self.port = port + def __init__(self, conn_params: ConnectionParams = ConnectionParams()): + super().__init__(conn_params.username) + self.conn_params = conn_params + self.host = conn_params.host self.ssh_key = None - self.username = username or self.get_user() + self.username = conn_params.username or self.get_user() # Command execution def exec_command(self, cmd, wait_exit=False, verbose=False, diff --git a/testgres/operations/os_ops.py b/testgres/operations/os_ops.py index 4b1349b7..9261cacf 100644 --- a/testgres/operations/os_ops.py +++ b/testgres/operations/os_ops.py @@ -7,11 +7,16 @@ raise ImportError("You must have psycopg2 or pg8000 modules installed") +class ConnectionParams: + def __init__(self, host='127.0.0.1', ssh_key=None, username=None): + self.host = host + self.ssh_key = ssh_key + self.username = username + + class OsOperations: def __init__(self, username=None): - self.hostname = "localhost" - self.remote = False - self.ssh = None + self.ssh_key = None self.username = username # Command execution diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index b27ca47d..0a90426c 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -8,9 +8,9 @@ import paramiko from paramiko import SSHClient -from testgres.exceptions import ExecUtilException +from ..exceptions import ExecUtilException -from .os_ops import OsOperations +from .os_ops import OsOperations, ConnectionParams from .os_ops import pglib sshtunnel.SSH_TIMEOUT = 5.0 @@ -37,15 +37,13 @@ def cmdline(self): class RemoteOperations(OsOperations): - def __init__(self, host="127.0.0.1", hostname='localhost', port=None, ssh_key=None, username=None): - super().__init__(username) - self.host = host - self.hostname = hostname - self.port = port - self.ssh_key = ssh_key - self.remote = True + def __init__(self, conn_params: 
ConnectionParams): + super().__init__(conn_params.username) + self.conn_params = conn_params + self.host = conn_params.host + self.ssh_key = conn_params.ssh_key self.ssh = self.ssh_connect() - self.username = username or self.get_user() + self.username = conn_params.username or self.get_user() self.tunnel = None def __enter__(self): @@ -70,14 +68,11 @@ def close_tunnel(self): time.sleep(0.5) def ssh_connect(self) -> Optional[SSHClient]: - if not self.remote: - return None - else: - key = self._read_ssh_key() - ssh = paramiko.SSHClient() - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - ssh.connect(self.host, username=self.username, pkey=key) - return ssh + key = self._read_ssh_key() + ssh = paramiko.SSHClient() + ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + ssh.connect(self.host, username=self.username, pkey=key) + return ssh def _read_ssh_key(self): try: diff --git a/tests/test_remote.py b/tests/test_remote.py index cdaa6574..ceb06ee3 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -2,20 +2,19 @@ from testgres import ExecUtilException from testgres import RemoteOperations +from testgres import ConnectionParams class TestRemoteOperations: @pytest.fixture(scope="function", autouse=True) def setup(self): - self.operations = RemoteOperations( - host="172.18.0.3", - username="dev", - ssh_key='../../container_files/postgres/ssh/id_ed25519' - ) + conn_params = ConnectionParams(host="172.18.0.3", + username="dev", + ssh_key='../../container_files/postgres/ssh/id_ed25519') + self.operations = RemoteOperations(conn_params) yield - self.operations.__del__() def test_exec_command_success(self): diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index f86e623f..80cf7674 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -50,12 +50,12 @@ # NOTE: those are ugly imports from testgres import bound_ports from testgres.utils import PgVer -from testgres.node import ProcessProxy +from testgres.node import ProcessProxy, ConnectionParams - -os_ops = RemoteOperations(host='172.18.0.3', - username='dev', - ssh_key='../../container_files/postgres/ssh/id_ed25519') +conn_params = ConnectionParams(host="172.18.0.3", + username="dev", + ssh_key='../../container_files/postgres/ssh/id_ed25519') +os_ops = RemoteOperations(conn_params) testgres_config.set_os_ops(os_ops=os_ops) @@ -94,7 +94,7 @@ def removing(f): def get_remote_node(name=None): - return get_new_node(name=name, host=os_ops.host, username=os_ops.username, ssh_key=os_ops.ssh_key) + return get_new_node(name=name, conn_params=conn_params) class TestgresRemoteTests(unittest.TestCase): From 089ab9b73afbc1fa60c9bab88d0c3e7e52bd5fd4 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Mon, 26 Jun 2023 02:37:59 +0200 Subject: [PATCH 18/23] PBCKP-588 fixes after review - remove f-strings --- testgres/connection.py | 2 +- testgres/node.py | 2 +- testgres/operations/local_ops.py | 6 ++-- testgres/operations/remote_ops.py | 58 +++++++++++++++---------------- testgres/utils.py | 2 +- tests/test_simple_remote.py | 5 ++- 6 files changed, 37 insertions(+), 38 deletions(-) diff --git a/testgres/connection.py b/testgres/connection.py index d28d81bd..aeb040ce 100644 --- a/testgres/connection.py +++ b/testgres/connection.py @@ -111,7 +111,7 @@ def execute(self, query, *args): return res except Exception as e: - print(f"Error executing query: {e}") + print("Error executing query: {}".format(e)) return None def close(self): diff --git a/testgres/node.py b/testgres/node.py index d12e7324..0f709b17 
100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -1006,7 +1006,7 @@ def safe_psql(self, query=None, expect_error=False, **kwargs): else: raise QueryException(err or b'', query) elif expect_error: - assert False, f"Exception was expected, but query finished successfully: `{query}` " + assert False, "Exception was expected, but query finished successfully: `{}` ".format(query) return out diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index bbe6b0d4..c2ee29cd 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -84,7 +84,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, if exit_status != 0 or found_error: if exit_status == 0: exit_status = 1 - raise ExecUtilException(message=f'Utility exited with non-zero code. Error `{error}`', + raise ExecUtilException(message='Utility exited with non-zero code. Error `{}`'.format(error), command=cmd, exit_code=exit_status, out=result) @@ -138,7 +138,7 @@ def pathsep(self): elif os_name == "nt": pathsep = ";" else: - raise Exception(f"Unsupported operating system: {os_name}") + raise Exception("Unsupported operating system: {}".format(os_name)) return pathsep def mkdtemp(self, prefix=None): @@ -242,7 +242,7 @@ def remove_file(self, filename): # Processes control def kill(self, pid, signal): # Kill the process - cmd = f"kill -{signal} {pid}" + cmd = "kill -{} {}".format(signal, pid) return self.exec_command(cmd) def get_pid(self): diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index 0a90426c..eb996f58 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -26,11 +26,11 @@ def __init__(self, ssh, pid): self.pid = pid def kill(self): - command = f"kill {self.pid}" + command = "kill {}".format(self.pid) self.ssh.exec_command(command) def cmdline(self): - command = f"ps -p {self.pid} -o cmd --no-headers" + command = "ps -p {} -o cmd --no-headers".format(self.pid) stdin, stdout, stderr = self.ssh.exec_command(command) cmdline = stdout.read().decode('utf-8').strip() return cmdline.split() @@ -84,9 +84,9 @@ def _read_ssh_key(self): key = paramiko.RSAKey.from_private_key_file(self.ssh_key) return key except FileNotFoundError: - raise ExecUtilException(message=f"No such file or directory: '{self.ssh_key}'") + raise ExecUtilException(message="No such file or directory: '{}'".format(self.ssh_key)) except Exception as e: - ExecUtilException(message=f"An error occurred while reading the ssh key: {e}") + ExecUtilException(message="An error occurred while reading the ssh key: {}".format(e)) def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, encoding=None, shell=True, text=False, input=None, stdout=None, @@ -131,7 +131,7 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa if error_found: if exit_status == 0: exit_status = 1 - raise ExecUtilException(message=f"Utility exited with non-zero code. Error: {error.decode(encoding or 'utf-8')}", + raise ExecUtilException(message="Utility exited with non-zero code. Error: {}".format(error.decode(encoding or 'utf-8')), command=cmd, exit_code=exit_status, out=result) @@ -148,7 +148,7 @@ def environ(self, var_name: str) -> str: Args: - var_name (str): The name of the environment variable. 
""" - cmd = f"echo ${var_name}" + cmd = "echo ${}".format(var_name) return self.exec_command(cmd, encoding='utf-8').strip() def find_executable(self, executable): @@ -166,7 +166,7 @@ def find_executable(self, executable): def is_executable(self, file): # Check if the file is executable - is_exec = self.exec_command(f"test -x {file} && echo OK") + is_exec = self.exec_command("test -x {} && echo OK".format(file)) return is_exec == b"OK\n" def set_env(self, var_name: str, var_val: str): @@ -176,7 +176,7 @@ def set_env(self, var_name: str, var_val: str): - var_name (str): The name of the environment variable. - var_val (str): The value to be set for the environment variable. """ - return self.exec_command(f"export {var_name}={var_val}") + return self.exec_command("export {}={}".format(var_name, var_val)) # Get environment variables def get_user(self): @@ -195,12 +195,12 @@ def makedirs(self, path, remove_existing=False): - remove_existing (bool): If True, the existing directory at the path will be removed. """ if remove_existing: - cmd = f"rm -rf {path} && mkdir -p {path}" + cmd = "rm -rf {} && mkdir -p {}".format(path, path) else: - cmd = f"mkdir -p {path}" + cmd = "mkdir -p {}".format(path) exit_status, result, error = self.exec_command(cmd, verbose=True) if exit_status != 0: - raise Exception(f"Couldn't create dir {path} because of error {error}") + raise Exception("Couldn't create dir {} because of error {}".format(path, error)) return result def rmdirs(self, path, verbose=False, ignore_errors=True): @@ -211,7 +211,7 @@ def rmdirs(self, path, verbose=False, ignore_errors=True): - verbose (bool): If True, return exit status, result, and error. - ignore_errors (bool): If True, do not raise error if directory does not exist. """ - cmd = f"rm -rf {path}" + cmd = "rm -rf {}".format(path) exit_status, result, error = self.exec_command(cmd, verbose=True) if verbose: return exit_status, result, error @@ -224,11 +224,11 @@ def listdir(self, path): Args: path (str): The path to the directory. """ - result = self.exec_command(f"ls {path}") + result = self.exec_command("ls {}".format(path)) return result.splitlines() def path_exists(self, path): - result = self.exec_command(f"test -e {path}; echo $?", encoding='utf-8') + result = self.exec_command("test -e {}; echo $?".format(path), encoding='utf-8') return int(result.strip()) == 0 @property @@ -239,7 +239,7 @@ def pathsep(self): elif os_name == "nt": pathsep = ";" else: - raise Exception(f"Unsupported operating system: {os_name}") + raise Exception("Unsupported operating system: {}".format(os_name)) return pathsep def mkdtemp(self, prefix=None): @@ -249,7 +249,7 @@ def mkdtemp(self, prefix=None): - prefix (str): The prefix of the temporary directory name. 
""" if prefix: - temp_dir = self.exec_command(f"mktemp -d {prefix}XXXXX", encoding='utf-8') + temp_dir = self.exec_command("mktemp -d {}XXXXX".format(prefix), encoding='utf-8') else: temp_dir = self.exec_command("mktemp -d", encoding='utf-8') @@ -262,7 +262,7 @@ def mkdtemp(self, prefix=None): def mkstemp(self, prefix=None): if prefix: - temp_dir = self.exec_command(f"mktemp {prefix}XXXXX", encoding='utf-8') + temp_dir = self.exec_command("mktemp {}XXXXX".format(prefix), encoding='utf-8') else: temp_dir = self.exec_command("mktemp", encoding='utf-8') @@ -277,8 +277,8 @@ def copytree(self, src, dst): if not os.path.isabs(dst): dst = os.path.join('~', dst) if self.isdir(dst): - raise FileExistsError(f"Directory {dst} already exists.") - return self.exec_command(f"cp -r {src} {dst}") + raise FileExistsError("Directory {} already exists.".format(dst)) + return self.exec_command("cp -r {} {}".format(src, dst)) # Work with files def write(self, filename, data, truncate=False, binary=False, read_and_write=False, encoding='utf-8'): @@ -344,10 +344,10 @@ def touch(self, filename): This method behaves as the 'touch' command in Unix. It's equivalent to calling 'touch filename' in the shell. """ - self.exec_command(f"touch {filename}") + self.exec_command("touch {}".format(filename)) def read(self, filename, binary=False, encoding=None): - cmd = f"cat {filename}" + cmd = "cat {}".format(filename) result = self.exec_command(cmd, encoding=encoding) if not binary and result: @@ -357,9 +357,9 @@ def read(self, filename, binary=False, encoding=None): def readlines(self, filename, num_lines=0, binary=False, encoding=None): if num_lines > 0: - cmd = f"tail -n {num_lines} {filename}" + cmd = "tail -n {} {}".format(num_lines, filename) else: - cmd = f"cat {filename}" + cmd = "cat {}".format(filename) result = self.exec_command(cmd, encoding=encoding) @@ -371,23 +371,23 @@ def readlines(self, filename, num_lines=0, binary=False, encoding=None): return lines def isfile(self, remote_file): - stdout = self.exec_command(f"test -f {remote_file}; echo $?") + stdout = self.exec_command("test -f {}; echo $?".format(remote_file)) result = int(stdout.strip()) return result == 0 def isdir(self, dirname): - cmd = f"if [ -d {dirname} ]; then echo True; else echo False; fi" + cmd = "if [ -d {} ]; then echo True; else echo False; fi".format(dirname) response = self.exec_command(cmd) return response.strip() == b"True" def remove_file(self, filename): - cmd = f"rm {filename}" + cmd = "rm {}".format(filename) return self.exec_command(cmd) # Processes control def kill(self, pid, signal): # Kill the process - cmd = f"kill -{signal} {pid}" + cmd = "kill -{} {}".format(signal, pid) return self.exec_command(cmd) def get_pid(self): @@ -395,7 +395,7 @@ def get_pid(self): return int(self.exec_command("echo $$", encoding='utf-8')) def get_process_children(self, pid): - command = f"pgrep -P {pid}" + command = "pgrep -P {}".format(pid) stdin, stdout, stderr = self.ssh.exec_command(command) children = stdout.readlines() return [PsUtilProcessProxy(self.ssh, int(child_pid.strip())) for child_pid in children] @@ -437,4 +437,4 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s return conn except Exception as e: self.tunnel.stop() - raise ExecUtilException("Could not create db tunnel.") + raise ExecUtilException("Could not create db tunnel. 
{}".format(e)) diff --git a/testgres/utils.py b/testgres/utils.py index 273c9287..1772d748 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -72,7 +72,7 @@ def execute_utility(args, logfile=None, verbose=False): lines = [u'\n'] + ['# ' + line for line in out.splitlines()] + [u'\n'] tconf.os_ops.write(filename=logfile, data=lines) except IOError: - raise ExecUtilException(f"Problem with writing to logfile `{logfile}` during run command `{args}`") + raise ExecUtilException("Problem with writing to logfile `{}` during run command `{}`".format(logfile, args)) if verbose: return exit_status, out, error else: diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 80cf7674..448a60ca 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -750,9 +750,8 @@ def test_pgbench(self): # run TPC-B benchmark out = node.pgbench(stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - options=['-T3']) - + stderr=subprocess.STDOUT, + options=['-T3']) self.assertTrue(b'tps = ' in out) def test_pg_config(self): From 190d084a7dbb00f4b844a5d3392194f47fd26073 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Tue, 27 Jun 2023 23:43:18 +0200 Subject: [PATCH 19/23] PBCKP-588 fixes after review - replace subprocess.run on subprocess.Popen --- testgres/operations/local_ops.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index c2ee29cd..fb47194f 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -60,27 +60,26 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, result = buf.read().decode(encoding) return result else: - if proc: - return subprocess.Popen(cmd, shell=shell, stdin=input, stdout=stdout, stderr=stderr) - process = subprocess.run( + process = subprocess.Popen( cmd, - input=input, shell=shell, - text=text, stdout=stdout, stderr=stderr, - timeout=CMD_TIMEOUT_SEC, ) + if proc: + return process + result, error = process.communicate(input) exit_status = process.returncode - result = process.stdout - error = process.stderr + found_error = "error" in error.decode(encoding or 'utf-8').lower() + if encoding: result = result.decode(encoding) error = error.decode(encoding) if expect_error: raise Exception(result, error) + if exit_status != 0 or found_error: if exit_status == 0: exit_status = 1 From 0c26f77db688c57f66bc2029dc92737612e8d736 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Wed, 28 Jun 2023 16:21:12 +0200 Subject: [PATCH 20/23] PBCKP-588 fix failed tests - psql, set_auto_conf --- testgres/node.py | 40 +++++++++++++++++--------- testgres/operations/local_ops.py | 36 ++++++++++++----------- testgres/operations/remote_ops.py | 13 ++++++--- tests/test_remote.py | 15 ++++++---- tests/test_simple_remote.py | 48 ++++++++++++------------------- 5 files changed, 82 insertions(+), 70 deletions(-) diff --git a/testgres/node.py b/testgres/node.py index 0f709b17..a146b08d 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -3,6 +3,7 @@ import os import random import signal +import subprocess import threading from queue import Queue @@ -714,14 +715,13 @@ def start(self, params=[], wait=True): exit_status, out, error = execute_utility(_params, self.utils_log_file, verbose=True) if 'does not exist' in error: raise Exception - if 'server started' in out: - self.is_started = True except Exception as e: msg = 'Cannot start node' files = self._collect_special_files() raise_from(StartNodeException(msg, files), e) 
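# A hedged aside on the subprocess change in PATCH 19 above: subprocess.run()
# only returns after the command finishes, so it cannot hand a live process
# back to the caller when proc=True; subprocess.Popen covers both paths. A
# minimal standalone sketch of the resulting exec_command flow (placeholder
# command, names as in local_ops.py):
#
#     process = subprocess.Popen(['pg_ctl', '--version'],
#                                shell=False,
#                                stdout=subprocess.PIPE,
#                                stderr=subprocess.PIPE)
#     if proc:
#         return process                  # caller drives the live process
#     result, error = process.communicate(input)
#     exit_status = process.returncode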
self._maybe_start_logger() + self.is_started = True return self def stop(self, params=[], wait=True): @@ -958,7 +958,10 @@ def psql(self, # select query source if query: - psql_params.extend(("-c", '"{}"'.format(query))) + if self.os_ops.remote: + psql_params.extend(("-c", '"{}"'.format(query))) + else: + psql_params.extend(("-c", query)) elif filename: psql_params.extend(("-f", filename)) else: @@ -966,11 +969,20 @@ def psql(self, # should be the last one psql_params.append(dbname) + if not self.os_ops.remote: + # start psql process + process = subprocess.Popen(psql_params, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + + # wait until it finishes and get stdout and stderr + out, err = process.communicate(input=input) + return process.returncode, out, err + else: + status_code, out, err = self.os_ops.exec_command(psql_params, verbose=True, input=input) - # start psql process - status_code, out, err = self.os_ops.exec_command(psql_params, verbose=True, input=input) - - return status_code, out, err + return status_code, out, err @method_decorator(positional_args_hack(['dbname', 'query'])) def safe_psql(self, query=None, expect_error=False, **kwargs): @@ -1002,9 +1014,9 @@ def safe_psql(self, query=None, expect_error=False, **kwargs): err = e.message if ret: if expect_error: - out = err or b'' + out = (err or b'').decode('utf-8') else: - raise QueryException(err or b'', query) + raise QueryException((err or b'').decode('utf-8'), query) elif expect_error: assert False, "Exception was expected, but query finished successfully: `{}` ".format(query) @@ -1529,18 +1541,18 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): Defaults to an empty set. """ # parse postgresql.auto.conf - auto_conf_file = os.path.join(self.data_dir, config) - raw_content = self.os_ops.read(auto_conf_file) + path = os.path.join(self.data_dir, config) + lines = self.os_ops.readlines(path) current_options = {} current_directives = [] - for line in raw_content.splitlines(): + for line in lines: # ignore comments if line.startswith('#'): continue - if line == '': + if line.strip() == '': continue if line.startswith('include'): @@ -1570,7 +1582,7 @@ def set_auto_conf(self, options, config='postgresql.auto.conf', rm_options={}): for directive in current_directives: auto_conf += directive + "\n" - self.os_ops.write(auto_conf_file, auto_conf) + self.os_ops.write(path, auto_conf, truncate=True) class NodeApp: diff --git a/testgres/operations/local_ops.py b/testgres/operations/local_ops.py index fb47194f..edd8cde2 100644 --- a/testgres/operations/local_ops.py +++ b/testgres/operations/local_ops.py @@ -1,6 +1,7 @@ import getpass import os import shutil +import stat import subprocess import tempfile from shutil import rmtree @@ -8,8 +9,7 @@ import psutil from ..exceptions import ExecUtilException - -from .os_ops import OsOperations, ConnectionParams +from .os_ops import ConnectionParams, OsOperations from .os_ops import pglib try: @@ -18,20 +18,24 @@ from distutils.spawn import find_executable CMD_TIMEOUT_SEC = 60 +error_markers = [b'error', b'Permission denied', b'fatal'] class LocalOperations(OsOperations): - def __init__(self, conn_params: ConnectionParams = ConnectionParams()): - super().__init__(conn_params.username) + def __init__(self, conn_params=None): + if conn_params is None: + conn_params = ConnectionParams() + super(LocalOperations, self).__init__(conn_params.username) self.conn_params = conn_params self.host = conn_params.host self.ssh_key = None + 
self.remote = False self.username = conn_params.username or self.get_user() # Command execution def exec_command(self, cmd, wait_exit=False, verbose=False, - expect_error=False, encoding=None, shell=True, text=False, - input=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): + expect_error=False, encoding=None, shell=False, text=False, + input=None, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, proc=None): """ Execute a command in a subprocess. @@ -49,9 +53,6 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, - proc: The process to use for subprocess creation. :return: The output of the subprocess. """ - if isinstance(cmd, list): - cmd = ' '.join(item.decode('utf-8') if isinstance(item, bytes) else item for item in cmd) - if os.name == 'nt': with tempfile.NamedTemporaryFile() as buf: process = subprocess.Popen(cmd, stdout=buf, stderr=subprocess.STDOUT) @@ -71,7 +72,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, result, error = process.communicate(input) exit_status = process.returncode - found_error = "error" in error.decode(encoding or 'utf-8').lower() + error_found = exit_status != 0 or any(marker in error for marker in error_markers) if encoding: result = result.decode(encoding) @@ -80,7 +81,7 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, if expect_error: raise Exception(result, error) - if exit_status != 0 or found_error: + if exit_status != 0 or error_found: if exit_status == 0: exit_status = 1 raise ExecUtilException(message='Utility exited with non-zero code. Error `{}`'.format(error), @@ -101,7 +102,7 @@ def find_executable(self, executable): def is_executable(self, file): # Check if the file is executable - return os.access(file, os.X_OK) + return os.stat(file).st_mode & stat.S_IXUSR def set_env(self, var_name, var_val): # Check if the directory is already in PATH @@ -116,9 +117,12 @@ def get_name(self): # Work with dirs def makedirs(self, path, remove_existing=False): - if remove_existing and os.path.exists(path): - shutil.rmtree(path) - os.makedirs(path, exist_ok=True) + if remove_existing: + shutil.rmtree(path, ignore_errors=True) + try: + os.makedirs(path) + except FileExistsError: + pass def rmdirs(self, path, ignore_errors=True): return rmtree(path, ignore_errors=ignore_errors) @@ -141,7 +145,7 @@ def pathsep(self): return pathsep def mkdtemp(self, prefix=None): - return tempfile.mkdtemp(prefix=prefix) + return tempfile.mkdtemp(prefix='{}'.format(prefix)) def mkstemp(self, prefix=None): fd, filename = tempfile.mkstemp(prefix=prefix) diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index eb996f58..bdeb423a 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -17,7 +17,7 @@ sshtunnel.TUNNEL_TIMEOUT = 5.0 -error_markers = [b'error', b'Permission denied'] +error_markers = [b'error', b'Permission denied', b'fatal'] class PsUtilProcessProxy: @@ -43,6 +43,7 @@ def __init__(self, conn_params: ConnectionParams): self.host = conn_params.host self.ssh_key = conn_params.ssh_key self.ssh = self.ssh_connect() + self.remote = True self.username = conn_params.username or self.get_user() self.tunnel = None @@ -89,7 +90,7 @@ def _read_ssh_key(self): ExecUtilException(message="An error occurred while reading the ssh key: {}".format(e)) def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=False, - encoding=None, shell=True, text=False, input=None, stdout=None, + encoding=None, shell=True, text=False, input=None, 
stdin=None, stdout=None, stderr=None, proc=None): """ Execute a command in the SSH session. @@ -131,7 +132,11 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa if error_found: if exit_status == 0: exit_status = 1 - raise ExecUtilException(message="Utility exited with non-zero code. Error: {}".format(error.decode(encoding or 'utf-8')), + if encoding: + message = "Utility exited with non-zero code. Error: {}".format(error.decode(encoding)) + else: + message = b"Utility exited with non-zero code. Error: " + error + raise ExecUtilException(message=message, command=cmd, exit_code=exit_status, out=result) @@ -429,7 +434,7 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432, s conn = pglib.connect( host=host, # change to 'localhost' because we're connecting through a local ssh tunnel port=self.tunnel.local_bind_port, # use the local bind port set up by the tunnel - dbname=dbname, + database=dbname, user=user or self.username, password=password ) diff --git a/tests/test_remote.py b/tests/test_remote.py index ceb06ee3..3794349c 100755 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -1,3 +1,5 @@ +import os + import pytest from testgres import ExecUtilException @@ -9,9 +11,10 @@ class TestRemoteOperations: @pytest.fixture(scope="function", autouse=True) def setup(self): - conn_params = ConnectionParams(host="172.18.0.3", - username="dev", - ssh_key='../../container_files/postgres/ssh/id_ed25519') + conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3', + username='dev', + ssh_key=os.getenv( + 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') self.operations = RemoteOperations(conn_params) yield @@ -35,7 +38,7 @@ def test_exec_command_failure(self): exit_status, result, error = self.operations.exec_command(cmd, verbose=True, wait_exit=True) except ExecUtilException as e: error = e.message - assert error == 'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n' + assert error == b'Utility exited with non-zero code. Error: bash: line 1: nonexistent_command: command not found\n' def test_is_executable_true(self): """ @@ -62,7 +65,7 @@ def test_makedirs_and_rmdirs_success(self): cmd = "pwd" pwd = self.operations.exec_command(cmd, wait_exit=True, encoding='utf-8').strip() - path = f"{pwd}/test_dir" + path = "{}/test_dir".format(pwd) # Test makedirs self.operations.makedirs(path) @@ -88,7 +91,7 @@ def test_makedirs_and_rmdirs_failure(self): exit_status, result, error = self.operations.rmdirs(path, verbose=True) except ExecUtilException as e: error = e.message - assert error == "Utility exited with non-zero code. Error: rm: cannot remove '/root/test_dir': Permission denied\n" + assert error == b"Utility exited with non-zero code. 
Error: rm: cannot remove '/root/test_dir': Permission denied\n" def test_listdir(self): """ diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py index 448a60ca..5028bc75 100755 --- a/tests/test_simple_remote.py +++ b/tests/test_simple_remote.py @@ -52,9 +52,10 @@ from testgres.utils import PgVer from testgres.node import ProcessProxy, ConnectionParams -conn_params = ConnectionParams(host="172.18.0.3", - username="dev", - ssh_key='../../container_files/postgres/ssh/id_ed25519') +conn_params = ConnectionParams(host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3', + username='dev', + ssh_key=os.getenv( + 'RDBMS_TESTPOOL_SSHKEY') or '../../container_files/postgres/ssh/id_ed25519') os_ops = RemoteOperations(conn_params) testgres_config.set_os_ops(os_ops=os_ops) @@ -148,14 +149,12 @@ def test_init_unique_system_id(self): with scoped_config(cache_initdb=True, cached_initdb_unique=True) as config: - self.assertTrue(config.cache_initdb) self.assertTrue(config.cached_initdb_unique) # spawn two nodes; ids must be different with get_remote_node().init().start() as node1, \ get_remote_node().init().start() as node2: - id1 = node1.execute(query)[0] id2 = node2.execute(query)[0] @@ -197,10 +196,10 @@ def test_restart(self): # restart, ok res = node.execute('select 1') - self.assertEqual(res, [(1, )]) + self.assertEqual(res, [(1,)]) node.restart() res = node.execute('select 2') - self.assertEqual(res, [(2, )]) + self.assertEqual(res, [(2,)]) # restart, fail with self.assertRaises(StartNodeException): @@ -262,7 +261,6 @@ def test_status(self): def test_psql(self): with get_remote_node().init().start() as node: - # check returned values (1 arg) res = node.psql('select 1') self.assertEqual(res, (0, b'1\n', b'')) @@ -306,7 +304,6 @@ def test_psql(self): def test_transactions(self): with get_remote_node().init().start() as node: - with node.connect() as con: con.begin() con.execute('create table test(val int)') @@ -316,12 +313,12 @@ def test_transactions(self): con.begin() con.execute('insert into test values (2)') res = con.execute('select * from test order by val asc') - self.assertListEqual(res, [(1, ), (2, )]) + self.assertListEqual(res, [(1,), (2,)]) con.rollback() con.begin() res = con.execute('select * from test') - self.assertListEqual(res, [(1, )]) + self.assertListEqual(res, [(1,)]) con.rollback() con.begin() @@ -330,7 +327,6 @@ def test_transactions(self): def test_control_data(self): with get_remote_node() as node: - # node is not initialized yet with self.assertRaises(ExecUtilException): node.get_control_data() @@ -344,7 +340,6 @@ def test_control_data(self): def test_backup_simple(self): with get_remote_node() as master: - # enable streaming for backups master.init(allow_streaming=True) @@ -361,7 +356,7 @@ def test_backup_simple(self): with master.backup(xlog_method='stream') as backup: with backup.spawn_primary().start() as slave: res = slave.execute('select * from test order by i asc') - self.assertListEqual(res, [(1, ), (2, ), (3, ), (4, )]) + self.assertListEqual(res, [(1,), (2,), (3,), (4,)]) def test_backup_multiple(self): with get_remote_node() as node: @@ -369,13 +364,11 @@ def test_backup_multiple(self): with node.backup(xlog_method='fetch') as backup1, \ node.backup(xlog_method='fetch') as backup2: - self.assertNotEqual(backup1.base_dir, backup2.base_dir) with node.backup(xlog_method='fetch') as backup: with backup.spawn_primary('node1', destroy=False) as node1, \ backup.spawn_primary('node2', destroy=False) as node2: - self.assertNotEqual(node1.base_dir, 
node2.base_dir) def test_backup_exhaust(self): @@ -383,7 +376,6 @@ def test_backup_exhaust(self): node.init(allow_streaming=True).start() with node.backup(xlog_method='fetch') as backup: - # exhaust backup by creating new node with backup.spawn_primary(): pass @@ -418,7 +410,7 @@ def test_replicate(self): with node.replicate().start() as replica: res = replica.execute('select 1') - self.assertListEqual(res, [(1, )]) + self.assertListEqual(res, [(1,)]) node.execute('create table test (val int)', commit=True) @@ -512,7 +504,7 @@ def test_logical_replication(self): node1.safe_psql('insert into test2 values (\'a\'), (\'b\')') sub.catchup() res = node2.execute('select * from test2') - self.assertListEqual(res, [('a', ), ('b', )]) + self.assertListEqual(res, [('a',), ('b',)]) # drop subscription sub.drop() @@ -530,12 +522,12 @@ def test_logical_replication(self): # explicitely add table with self.assertRaises(ValueError): - pub.add_tables([]) # fail + pub.add_tables([]) # fail pub.add_tables(['test2']) node1.safe_psql('insert into test2 values (\'c\')') sub.catchup() res = node2.execute('select * from test2') - self.assertListEqual(res, [('a', ), ('b', )]) + self.assertListEqual(res, [('a',), ('b',)]) @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') def test_logical_catchup(self): @@ -619,7 +611,7 @@ def test_dump(self): # restore dump node3.restore(filename=dump) res = node3.execute(query_select) - self.assertListEqual(res, [(1, ), (2, )]) + self.assertListEqual(res, [(1,), (2,)]) def test_users(self): with get_remote_node().init().start() as node: @@ -651,7 +643,7 @@ def test_poll_query_until(self): # check None, ok node.poll_query_until(query='create table def()', - expected=None) # returns nothing + expected=None) # returns nothing # check 0 rows equivalent to expected=None node.poll_query_until( @@ -697,7 +689,7 @@ def test_logging(self): 'handlers': { 'file': { 'class': 'logging.FileHandler', - 'filename': logfile, + 'filename': logfile.name, 'formatter': 'base_format', 'level': logging.DEBUG, }, @@ -708,7 +700,7 @@ def test_logging(self): }, }, 'root': { - 'handlers': ('file', ), + 'handlers': ('file',), 'level': 'DEBUG', }, } @@ -730,7 +722,7 @@ def test_logging(self): time.sleep(0.1) # check that master's port is found - with open(logfile, 'r') as log: + with open(logfile.name, 'r') as log: lines = log.readlines() self.assertTrue(any(node_name in s for s in lines)) @@ -743,7 +735,6 @@ def test_logging(self): @unittest.skipUnless(util_exists('pgbench'), 'might be missing') def test_pgbench(self): with get_remote_node().init().start() as node: - # initialize pgbench DB and run benchmarks node.pgbench_init(scale=2, foreign_keys=True, options=['-q']).pgbench_run(time=2) @@ -764,7 +755,6 @@ def test_pg_config(self): c1 = get_pg_config() # modify setting for this scope with scoped_config(cache_pg_config=False) as config: - # sanity check for value self.assertFalse(config.cache_pg_config) @@ -796,7 +786,6 @@ def test_config_stack(self): self.assertEqual(c1.cached_initdb_dir, d1) with scoped_config(cached_initdb_dir=d2) as c2: - stack_size = len(testgres.config.config_stack) # try to break a stack @@ -830,7 +819,6 @@ def test_unix_sockets(self): def test_auto_name(self): with get_remote_node().init(allow_streaming=True).start() as m: with m.replicate().start() as r: - # check that nodes are running self.assertTrue(m.status()) self.assertTrue(r.status()) From 0796bc4b334745e59857d2d0a8f8de2d107023a3 Mon Sep 17 00:00:00 2001 From: "v.shepard" Date: Wed, 26 Jul 2023 09:32:30 +0200 
Subject: [PATCH 21/23] PBCKP-152 - test_restore_target_time cut

---
 testgres/node.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/testgres/node.py b/testgres/node.py
index a146b08d..244f3c1f 100644
--- a/testgres/node.py
+++ b/testgres/node.py
@@ -659,7 +659,7 @@ def get_control_data(self):

         return out_dict

-    def slow_start(self, replica=False, dbname='template1', username=default_username()):
+    def slow_start(self, replica=False, dbname='template1', username=default_username(), max_attempts=0):
         """
         Starts the PostgreSQL instance and then polls the instance
         until it reaches the expected state (primary or replica). The state is checked
@@ -670,6 +670,7 @@ def slow_start(self, replica=False, dbname='template1', username=default_usernam
             username:
             replica: If True, waits for the instance to be in recovery (i.e., replica mode).
                      If False, waits for the instance to be in primary mode. Default is False.
+            max_attempts: Maximum number of polling attempts; 0 (the default) means no limit.
         """
         self.start()

@@ -684,7 +685,8 @@ def slow_start(self, replica=False, dbname='template1', username=default_usernam
             suppress={InternalError,
                       QueryException,
                       ProgrammingError,
-                      OperationalError})
+                      OperationalError},
+            max_attempts=max_attempts)

     def start(self, params=[], wait=True):
         """
@@ -719,7 +721,6 @@ def start(self, params=[], wait=True):
             msg = 'Cannot start node'
             files = self._collect_special_files()
             raise_from(StartNodeException(msg, files), e)
-
         self._maybe_start_logger()
         self.is_started = True
         return self
@@ -1139,9 +1140,9 @@ def poll_query_until(self,
         # sanity checks
         assert max_attempts >= 0
         assert sleep_time > 0
-
        attempts = 0
        while max_attempts == 0 or attempts < max_attempts:
+            print(f"Polling attempt {attempts}")
            try:
                res = self.execute(dbname=dbname,
                                   query=query,

From 0f14034bdef296144016a53bae1e601bc243cc08 Mon Sep 17 00:00:00 2001
From: "v.shepard"
Date: Fri, 28 Jul 2023 09:38:08 +0200
Subject: [PATCH 22/23] PBCKP-152 - node set listen address

---
 testgres/node.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/testgres/node.py b/testgres/node.py
index 244f3c1f..fb259b89 100644
--- a/testgres/node.py
+++ b/testgres/node.py
@@ -530,7 +530,7 @@ def get_auth_method(t):
         self.append_conf(fsync=fsync,
                          max_worker_processes=MAX_WORKER_PROCESSES,
                          log_statement=log_statement,
-                         listen_addresses='*',
+                         listen_addresses=self.host,
                          port=self.port)  # yapf:disable

         # common replication settings

From 12aa7bab9df4a6a2c1e4dca4df5d7f6f17bdba41 Mon Sep 17 00:00:00 2001
From: "v.shepard"
Date: Wed, 2 Aug 2023 00:50:33 +0200
Subject: [PATCH 23/23] Add info about remote mode in README.md

---
 README.md                         | 27 +++++++++
 testgres/__init__.py              |  3 +-
 testgres/api.py                   | 12 +++-
 testgres/node.py                  |  6 +-
 testgres/operations/remote_ops.py |  7 ++-
 testgres/utils.py                 | 13 ++++-
 tests/README.md                   | 29 ++++++++++
 tests/test_simple_remote.py       | 92 +++++++++++++++----------------
 8 files changed, 130 insertions(+), 59 deletions(-)

diff --git a/README.md b/README.md
index 6b26ba96..29b974dc 100644
--- a/README.md
+++ b/README.md
@@ -173,6 +173,33 @@ with testgres.get_new_node().init() as master:
 Note that `default_conf()` is called by `init()` function; both of them overwrite
 the configuration file, which means that they should be called before `append_conf()`.

+### Remote mode
+Testgres supports the creation of PostgreSQL nodes on a remote host. This is useful when you want to run distributed tests involving multiple nodes spread across different machines.
+To use this feature, you need to use the RemoteOperations class.
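+RemoteOperations runs every command on the target host over SSH (via paramiko) and opens an SSH tunnel for database connections, so the rest of the testgres API works unchanged. One pattern worth noting, sketched here from this repository's own test suite, is taking the host and key from environment variables so nothing is hard-coded (the fallback literals below are placeholders):
+
+```python
+import os
+
+from testgres import ConnectionParams, RemoteOperations
+
+# Environment first, placeholder fallbacks second, as in tests/test_simple_remote.py.
+conn_params = ConnectionParams(
+    host=os.getenv('RDBMS_TESTPOOL1_HOST') or '172.18.0.3',
+    username='dev',  # replace with your remote user
+    ssh_key=os.getenv('RDBMS_TESTPOOL_SSHKEY') or '/path/to/ssh/key')
+os_ops = RemoteOperations(conn_params)
+```
+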
+Here is an example of how you might set this up: + +```python +from testgres import ConnectionParams, RemoteOperations, TestgresConfig, get_remote_node + +# Set up connection params +conn_params = ConnectionParams( + host='your_host', # replace with your host + username='user_name', # replace with your username + ssh_key='path_to_ssh_key' # replace with your SSH key path +) +os_ops = RemoteOperations(conn_params) + +# Add remote testgres config before test +TestgresConfig.set_os_ops(os_ops=os_ops) + +# Proceed with your test +def test_basic_query(self): + with get_remote_node(conn_params=conn_params) as node: + node.init().start() + res = node.execute('SELECT 1') + self.assertEqual(res, [(1,)]) +``` ## Authors diff --git a/testgres/__init__.py b/testgres/__init__.py index ce2636b4..b63c7df1 100644 --- a/testgres/__init__.py +++ b/testgres/__init__.py @@ -1,4 +1,4 @@ -from .api import get_new_node +from .api import get_new_node, get_remote_node from .backup import NodeBackup from .config import \ @@ -52,6 +52,7 @@ __all__ = [ "get_new_node", + "get_remote_node", "NodeBackup", "TestgresConfig", "configure_testgres", "scoped_config", "push_config", "pop_config", "NodeConnection", "DatabaseError", "InternalError", "ProgrammingError", "OperationalError", diff --git a/testgres/api.py b/testgres/api.py index 8f553529..e4b1cdd5 100644 --- a/testgres/api.py +++ b/testgres/api.py @@ -37,10 +37,18 @@ def get_new_node(name=None, base_dir=None, **kwargs): """ Simply a wrapper around :class:`.PostgresNode` constructor. See :meth:`.PostgresNode.__init__` for details. + """ + # NOTE: leave explicit 'name' and 'base_dir' for compatibility + return PostgresNode(name=name, base_dir=base_dir, **kwargs) + + +def get_remote_node(name=None, conn_params=None): + """ + Simply a wrapper around :class:`.PostgresNode` constructor for remote node. + See :meth:`.PostgresNode.__init__` for details. 
For remote connection you can add the next parameter: conn_params = ConnectionParams(host='127.0.0.1', ssh_key=None, username=default_username()) """ - # NOTE: leave explicit 'name' and 'base_dir' for compatibility - return PostgresNode(name=name, base_dir=base_dir, **kwargs) + return get_new_node(name=name, conn_params=conn_params) diff --git a/testgres/node.py b/testgres/node.py index fb259b89..6483514b 100644 --- a/testgres/node.py +++ b/testgres/node.py @@ -146,8 +146,9 @@ def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionP # basic self.name = name or generate_app_name() - - if conn_params.ssh_key: + if testgres_config.os_ops: + self.os_ops = testgres_config.os_ops + elif conn_params.ssh_key: self.os_ops = RemoteOperations(conn_params) else: self.os_ops = LocalOperations(conn_params) @@ -157,7 +158,6 @@ def __init__(self, name=None, port=None, base_dir=None, conn_params: ConnectionP self.host = self.os_ops.host self.ssh_key = self.os_ops.ssh_key - testgres_config.os_ops = self.os_ops # defaults for __exit__() self.cleanup_on_good_exit = testgres_config.node_cleanup_on_good_exit self.cleanup_on_bad_exit = testgres_config.node_cleanup_on_bad_exit diff --git a/testgres/operations/remote_ops.py b/testgres/operations/remote_ops.py index bdeb423a..6815c7f1 100644 --- a/testgres/operations/remote_ops.py +++ b/testgres/operations/remote_ops.py @@ -17,7 +17,7 @@ sshtunnel.TUNNEL_TIMEOUT = 5.0 -error_markers = [b'error', b'Permission denied', b'fatal'] +error_markers = [b'error', b'Permission denied', b'fatal', b'No such file or directory'] class PsUtilProcessProxy: @@ -203,7 +203,10 @@ def makedirs(self, path, remove_existing=False): cmd = "rm -rf {} && mkdir -p {}".format(path, path) else: cmd = "mkdir -p {}".format(path) - exit_status, result, error = self.exec_command(cmd, verbose=True) + try: + exit_status, result, error = self.exec_command(cmd, verbose=True) + except ExecUtilException as e: + raise Exception("Couldn't create dir {} because of error {}".format(path, e.message)) if exit_status != 0: raise Exception("Couldn't create dir {} because of error {}".format(path, error)) return result diff --git a/testgres/utils.py b/testgres/utils.py index 1772d748..58e18deb 100644 --- a/testgres/utils.py +++ b/testgres/utils.py @@ -87,9 +87,12 @@ def get_bin_path(filename): # check if it's already absolute if os.path.isabs(filename): return filename + if tconf.os_ops.remote: + pg_config = os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") + else: + # try PG_CONFIG - get from local machine + pg_config = os.environ.get("PG_CONFIG") - # try PG_CONFIG - get from local machine - pg_config = os.environ.get("PG_CONFIG") if pg_config: bindir = get_pg_config()["BINDIR"] return os.path.join(bindir, filename) @@ -139,7 +142,11 @@ def cache_pg_config_data(cmd): return _pg_config_data # try specified pg_config path or PG_CONFIG - pg_config = pg_config_path or os.environ.get("PG_CONFIG") + if tconf.os_ops.remote: + pg_config = pg_config_path or os.environ.get("PG_CONFIG_REMOTE") or os.environ.get("PG_CONFIG") + else: + # try PG_CONFIG - get from local machine + pg_config = pg_config_path or os.environ.get("PG_CONFIG") if pg_config: return cache_pg_config_data(pg_config) diff --git a/tests/README.md b/tests/README.md index a6d50992..d89efc7e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -27,3 +27,32 @@ export PYTHON_VERSION=3 # or 2 # Run tests ./run_tests.sh ``` + + +#### Remote host tests + +1. Start remote host or docker container +2. 
+
+`test_remote` - Tests for the RemoteOperations class.
+
+`test_simple_remote` - Tests that create a node and check it; the same as `test_simple`, but for a remote node.
\ No newline at end of file
diff --git a/tests/test_simple_remote.py b/tests/test_simple_remote.py
index 5028bc75..e8386383 100755
--- a/tests/test_simple_remote.py
+++ b/tests/test_simple_remote.py
@@ -35,7 +35,7 @@
     NodeStatus, \
     ProcessType, \
     IsolationLevel, \
-    get_new_node, \
+    get_remote_node, \
     RemoteOperations
 
 from testgres import \
@@ -94,23 +94,19 @@ def removing(f):
         os_ops.rmdirs(f, ignore_errors=True)
 
 
-def get_remote_node(name=None):
-    return get_new_node(name=name, conn_params=conn_params)
-
-
 class TestgresRemoteTests(unittest.TestCase):
     def test_node_repr(self):
-        with get_remote_node() as node:
+        with get_remote_node(conn_params=conn_params) as node:
             pattern = r"PostgresNode\(name='.+', port=.+, base_dir='.+'\)"
             self.assertIsNotNone(re.match(pattern, str(node)))
 
     def test_custom_init(self):
-        with get_remote_node() as node:
+        with get_remote_node(conn_params=conn_params) as node:
             # enable page checksums
             node.init(initdb_params=['-k']).start()
 
-        with get_remote_node() as node:
+        with get_remote_node(conn_params=conn_params) as node:
             node.init(
                 allow_streaming=True,
                 initdb_params=['--auth-local=reject', '--auth-host=reject'])
@@ -125,13 +121,13 @@ def test_custom_init(self):
             self.assertFalse(any('trust' in s for s in lines))
 
     def test_double_init(self):
-        with get_remote_node().init() as node:
+        with get_remote_node(conn_params=conn_params).init() as node:
             # can't initialize node more than once
             with self.assertRaises(InitNodeException):
                 node.init()
 
     def test_init_after_cleanup(self):
-        with get_remote_node() as node:
+        with get_remote_node(conn_params=conn_params) as node:
             node.init().start().execute('select 1')
             node.cleanup()
             node.init().start().execute('select 1')
@@ -144,7 +140,7 @@ def test_init_unique_system_id(self):
         query = 'select system_identifier from pg_control_system()'
 
         with scoped_config(cache_initdb=False):
-            with get_remote_node().init().start() as node0:
+            with get_remote_node(conn_params=conn_params).init().start() as node0:
                 id0 = node0.execute(query)[0]
 
         with scoped_config(cache_initdb=True,
@@ -153,8 +149,8 @@
             self.assertTrue(config.cached_initdb_unique)
 
             # spawn two nodes; ids must be different
-            with get_remote_node().init().start() as node1, \
-                 get_remote_node().init().start() as node2:
+            with get_remote_node(conn_params=conn_params).init().start() as node1, \
+                 get_remote_node(conn_params=conn_params).init().start() as node2:
                 id1 = node1.execute(query)[0]
                 id2 = node2.execute(query)[0]
 
@@ -164,7 +160,7 @@
 
     def test_node_exit(self):
         with self.assertRaises(QueryException):
-            with get_remote_node().init() as node:
+            with get_remote_node(conn_params=conn_params).init() as node:
                 base_dir = node.base_dir
                 node.safe_psql('select 1')
 
@@ -172,26
+168,26 @@ def test_node_exit(self): self.assertTrue(os_ops.path_exists(base_dir)) os_ops.rmdirs(base_dir, ignore_errors=True) - with get_remote_node().init() as node: + with get_remote_node(conn_params=conn_params).init() as node: base_dir = node.base_dir # should have been removed by default self.assertFalse(os_ops.path_exists(base_dir)) def test_double_start(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: # can't start node more than once node.start() self.assertTrue(node.is_started) def test_uninitialized_start(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: # node is not initialized yet with self.assertRaises(StartNodeException): node.start() def test_restart(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start() # restart, ok @@ -207,7 +203,7 @@ def test_restart(self): node.restart() def test_reload(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start() # change client_min_messages and save old value @@ -223,7 +219,7 @@ def test_reload(self): self.assertNotEqual(cmm_old, cmm_new) def test_pg_ctl(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start() status = node.pg_ctl(['status']) @@ -235,7 +231,7 @@ def test_status(self): self.assertFalse(NodeStatus.Uninitialized) # check statuses after each operation - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: self.assertEqual(node.pid, 0) self.assertEqual(node.status(), NodeStatus.Uninitialized) @@ -260,7 +256,7 @@ def test_status(self): self.assertEqual(node.status(), NodeStatus.Uninitialized) def test_psql(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: # check returned values (1 arg) res = node.psql('select 1') self.assertEqual(res, (0, b'1\n', b'')) @@ -303,7 +299,7 @@ def test_psql(self): node.safe_psql('select 1') def test_transactions(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: with node.connect() as con: con.begin() con.execute('create table test(val int)') @@ -326,7 +322,7 @@ def test_transactions(self): con.commit() def test_control_data(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: # node is not initialized yet with self.assertRaises(ExecUtilException): node.get_control_data() @@ -339,7 +335,7 @@ def test_control_data(self): self.assertTrue(any('pg_control' in s for s in data.keys())) def test_backup_simple(self): - with get_remote_node() as master: + with get_remote_node(conn_params=conn_params) as master: # enable streaming for backups master.init(allow_streaming=True) @@ -359,7 +355,7 @@ def test_backup_simple(self): self.assertListEqual(res, [(1,), (2,), (3,), (4,)]) def test_backup_multiple(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() with node.backup(xlog_method='fetch') as backup1, \ @@ -372,7 +368,7 @@ def test_backup_multiple(self): self.assertNotEqual(node1.base_dir, node2.base_dir) def test_backup_exhaust(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() with 
node.backup(xlog_method='fetch') as backup: @@ -385,7 +381,7 @@ def test_backup_exhaust(self): backup.spawn_primary() def test_backup_wrong_xlog_method(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() with self.assertRaises(BackupException, @@ -393,7 +389,7 @@ def test_backup_wrong_xlog_method(self): node.backup(xlog_method='wrong') def test_pg_ctl_wait_option(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start(wait=False) while True: try: @@ -405,7 +401,7 @@ def test_pg_ctl_wait_option(self): pass def test_replicate(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() with node.replicate().start() as replica: @@ -421,7 +417,7 @@ def test_replicate(self): @unittest.skipUnless(pg_version_ge('9.6'), 'requires 9.6+') def test_synchronous_replication(self): - with get_remote_node() as master: + with get_remote_node(conn_params=conn_params) as master: old_version = not pg_version_ge('9.6') master.init(allow_streaming=True).start() @@ -462,7 +458,7 @@ def test_synchronous_replication(self): @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') def test_logical_replication(self): - with get_remote_node() as node1, get_remote_node() as node2: + with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2: node1.init(allow_logical=True) node1.start() node2.init().start() @@ -532,7 +528,7 @@ def test_logical_replication(self): @unittest.skipUnless(pg_version_ge('10'), 'requires 10+') def test_logical_catchup(self): """ Runs catchup for 100 times to be sure that it is consistent """ - with get_remote_node() as node1, get_remote_node() as node2: + with get_remote_node(conn_params=conn_params) as node1, get_remote_node(conn_params=conn_params) as node2: node1.init(allow_logical=True) node1.start() node2.init().start() @@ -557,12 +553,12 @@ def test_logical_catchup(self): @unittest.skipIf(pg_version_ge('10'), 'requires <10') def test_logical_replication_fail(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: with self.assertRaises(InitNodeException): node.init(allow_logical=True) def test_replication_slots(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() with node.replicate(slot='slot1').start() as replica: @@ -573,7 +569,7 @@ def test_replication_slots(self): node.replicate(slot='slot1') def test_incorrect_catchup(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(allow_streaming=True).start() # node has no master, can't catch up @@ -581,7 +577,7 @@ def test_incorrect_catchup(self): node.catchup() def test_promotion(self): - with get_remote_node() as master: + with get_remote_node(conn_params=conn_params) as master: master.init().start() master.safe_psql('create table abc(id serial)') @@ -598,12 +594,12 @@ def test_dump(self): query_create = 'create table test as select generate_series(1, 2) as val' query_select = 'select * from test order by val asc' - with get_remote_node().init().start() as node1: + with get_remote_node(conn_params=conn_params).init().start() as node1: node1.execute(query_create) for format in ['plain', 'custom', 'directory', 'tar']: with removing(node1.dump(format=format)) as dump: - with 
get_remote_node().init().start() as node3: + with get_remote_node(conn_params=conn_params).init().start() as node3: if format == 'directory': self.assertTrue(os_ops.isdir(dump)) else: @@ -614,13 +610,13 @@ def test_dump(self): self.assertListEqual(res, [(1,), (2,)]) def test_users(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: node.psql('create role test_user login') value = node.safe_psql('select 1', username='test_user') self.assertEqual(b'1\n', value) def test_poll_query_until(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init().start() get_time = 'select extract(epoch from now())' @@ -734,7 +730,7 @@ def test_logging(self): @unittest.skipUnless(util_exists('pgbench'), 'might be missing') def test_pgbench(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: # initialize pgbench DB and run benchmarks node.pgbench_init(scale=2, foreign_keys=True, options=['-q']).pgbench_run(time=2) @@ -801,7 +797,7 @@ def test_config_stack(self): self.assertEqual(TestgresConfig.cached_initdb_dir, d0) def test_unix_sockets(self): - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: node.init(unix_sockets=False, allow_streaming=True) node.start() @@ -817,7 +813,7 @@ def test_unix_sockets(self): self.assertEqual(res_psql, b'1\n') def test_auto_name(self): - with get_remote_node().init(allow_streaming=True).start() as m: + with get_remote_node(conn_params=conn_params).init(allow_streaming=True).start() as m: with m.replicate().start() as r: # check that nodes are running self.assertTrue(m.status()) @@ -854,7 +850,7 @@ def test_file_tail(self): self.assertEqual(lines[0], s3) def test_isolation_levels(self): - with get_remote_node().init().start() as node: + with get_remote_node(conn_params=conn_params).init().start() as node: with node.connect() as con: # string levels con.begin('Read Uncommitted').commit() @@ -876,7 +872,7 @@ def test_ports_management(self): # check that no ports have been bound yet self.assertEqual(len(bound_ports), 0) - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: # check that we've just bound a port self.assertEqual(len(bound_ports), 1) @@ -909,7 +905,7 @@ def test_version_management(self): self.assertTrue(d > f) version = get_pg_version() - with get_remote_node() as node: + with get_remote_node(conn_params=conn_params) as node: self.assertTrue(isinstance(version, six.string_types)) self.assertTrue(isinstance(node.version, PgVer)) self.assertEqual(node.version, PgVer(version)) @@ -932,7 +928,7 @@ def test_child_pids(self): ProcessType.WalReceiver, ] - with get_remote_node().init().start() as master: + with get_remote_node(conn_params=conn_params).init().start() as master: # master node doesn't have a source walsender! with self.assertRaises(TestgresException):