1- import io
21import os
32import tempfile
4- from contextlib import contextmanager
3+ from typing import Optional
54
6- from testgres .logger import log
5+ import paramiko
6+ from paramiko import SSHClient
77
8+ from logger import log
89from .os_ops import OsOperations
910from .os_ops import pglib
1011
11- import paramiko
12+ error_markers = [ b'error' , b'Permission denied' ]
1213
1314
1415class RemoteOperations (OsOperations ):
15- """
16- This class specifically supports work with Linux systems. It utilizes the SSH
17- for making connections and performing various file and directory operations, command executions,
18- environment setup and management, process control, and database connections.
19- It uses the Paramiko library for SSH connections and operations.
20-
21- Some methods are designed to work with specific Linux shell commands, and thus may not work as expected
22- on other non-Linux systems.
23-
24- Attributes:
25- - hostname (str): The remote system's hostname. Default 'localhost'.
26- - host (str): The remote system's IP address. Default '127.0.0.1'.
27- - ssh_key (str): Path to the SSH private key for authentication.
28- - username (str): Username for the remote system.
29- - ssh (paramiko.SSHClient): SSH connection to the remote system.
30- """
31-
32- def __init__ (
33- self , hostname = "localhost" , host = "127.0.0.1" , ssh_key = None , username = None
34- ):
16+ def __init__ (self , hostname = "localhost" , host = "127.0.0.1" , ssh_key = None , username = None ):
3517 super ().__init__ (username )
36- self .hostname = hostname
3718 self .host = host
3819 self .ssh_key = ssh_key
3920 self .remote = True
40- self .ssh = self .connect ()
21+ self .ssh = self .ssh_connect ()
4122 self .username = username or self .get_user ()
4223
4324 def __del__ (self ):
4425 if self .ssh :
4526 self .ssh .close ()
4627
47- @contextmanager
48- def ssh_connect (self ):
28+ def ssh_connect (self ) -> Optional [SSHClient ]:
4929 if not self .remote :
50- yield None
30+ return None
5131 else :
32+ key = self ._read_ssh_key ()
33+ ssh = paramiko .SSHClient ()
34+ ssh .set_missing_host_key_policy (paramiko .AutoAddPolicy ())
35+ ssh .connect (self .host , username = self .username , pkey = key )
36+ return ssh
37+
38+ def _read_ssh_key (self ):
39+ try :
5240 with open (self .ssh_key , "r" ) as f :
5341 key_data = f .read ()
5442 if "BEGIN OPENSSH PRIVATE KEY" in key_data :
5543 key = paramiko .Ed25519Key .from_private_key_file (self .ssh_key )
5644 else :
5745 key = paramiko .RSAKey .from_private_key_file (self .ssh_key )
46+ return key
47+ except FileNotFoundError :
48+ log .error (f"No such file or directory: '{ self .ssh_key } '" )
49+ except Exception as e :
50+ log .error (f"An error occurred while reading the ssh key: { e } " )
5851
59- with paramiko .SSHClient () as ssh :
60- ssh .set_missing_host_key_policy (paramiko .AutoAddPolicy ())
61- ssh .connect (self .host , username = self .username , pkey = key )
62- yield ssh
63-
64- def connect (self ):
65- with self .ssh_connect () as ssh :
66- return ssh
67-
68- # Command execution
69- def exec_command (self , cmd , wait_exit = False , verbose = False ,
70- expect_error = False , encoding = None , shell = True , text = False ,
71- input = None , stdout = None , stderr = None , proc = None ):
52+ def exec_command (self , cmd : str , wait_exit = False , verbose = False , expect_error = False ,
53+ encoding = None , shell = True , text = False , input = None , stdout = None ,
54+ stderr = None , proc = None ):
55+ """
56+ Execute a command in the SSH session.
57+ Args:
58+ - cmd (str): The command to be executed.
59+ """
7260 if isinstance (cmd , list ):
7361 cmd = " " .join (cmd )
74- log .debug (f"os_ops.exec_command: `{ cmd } `; remote={ self .remote } " )
75- # Source global profile file + execute command
7662 try :
77- cmd = f"source /etc/profile.d/custom.sh; { cmd } "
78- with self .ssh_connect () as ssh :
79- if input :
80- # encode input and feed it to stdin
81- stdin , stdout , stderr = ssh .exec_command (cmd )
82- stdin .write (input )
83- stdin .flush ()
84- else :
85- stdin , stdout , stderr = ssh .exec_command (cmd )
86- exit_status = 0
87- if wait_exit :
88- exit_status = stdout .channel .recv_exit_status ()
89- if encoding :
90- result = stdout .read ().decode (encoding )
91- error = stderr .read ().decode (encoding )
92- else :
93- # Save as binary string
94- result = io .BytesIO (stdout .read ()).getvalue ()
95- error = io .BytesIO (stderr .read ()).getvalue ()
96- error_str = stderr .read ()
63+ if input :
64+ stdin , stdout , stderr = self .ssh .exec_command (cmd )
65+ stdin .write (input .encode ("utf-8" ))
66+ stdin .flush ()
67+ else :
68+ stdin , stdout , stderr = self .ssh .exec_command (cmd )
69+ exit_status = 0
70+ if wait_exit :
71+ exit_status = stdout .channel .recv_exit_status ()
72+
73+ if encoding :
74+ result = stdout .read ().decode (encoding )
75+ error = stderr .read ().decode (encoding )
76+ else :
77+ result = stdout .read ()
78+ error = stderr .read ()
9779
9880 if expect_error :
9981 raise Exception (result , error )
100- if exit_status != 0 or 'error' in error_str :
82+
83+ if encoding :
84+ error_found = exit_status != 0 or any (
85+ marker .decode (encoding ) in error for marker in error_markers )
86+ else :
87+ error_found = exit_status != 0 or any (
88+ marker in error for marker in error_markers )
89+
90+ if error_found :
10191 log .error (
10292 f"Problem in executing command: `{ cmd } `\n error: { error } \n exit_code: { exit_status } "
10393 )
94+ if exit_status == 0 :
95+ exit_status = 1
10496
10597 if verbose :
10698 return exit_status , result , error
@@ -112,7 +104,12 @@ def exec_command(self, cmd, wait_exit=False, verbose=False,
112104 return None
113105
114106 # Environment setup
115- def environ (self , var_name ):
107+ def environ (self , var_name : str ) -> str :
108+ """
109+ Get the value of an environment variable.
110+ Args:
111+ - var_name (str): The name of the environment variable.
112+ """
116113 cmd = f"echo ${ var_name } "
117114 return self .exec_command (cmd ).strip ()
118115
@@ -131,7 +128,8 @@ def find_executable(self, executable):
131128
132129 def is_executable (self , file ):
133130 # Check if the file is executable
134- return self .exec_command (f"test -x { file } && echo OK" ) == "OK\n "
131+ is_exec = self .exec_command (f"test -x { file } && echo OK" )
132+ return is_exec == b"OK\n "
135133
136134 def add_to_path (self , new_path ):
137135 pathsep = self .pathsep
@@ -144,8 +142,13 @@ def add_to_path(self, new_path):
144142 os .environ ["PATH" ] = f"{ new_path } { pathsep } { path } "
145143 return pathsep
146144
147- def set_env (self , var_name , var_val ):
148- # Check if the directory is already in PATH
145+ def set_env (self , var_name : str , var_val : str ) -> None :
146+ """
147+ Set the value of an environment variable.
148+ Args:
149+ - var_name (str): The name of the environment variable.
150+ - var_val (str): The value to be set for the environment variable.
151+ """
149152 return self .exec_command (f"export { var_name } ={ var_val } " )
150153
151154 # Get environment variables
@@ -158,22 +161,47 @@ def get_name(self):
158161
159162 # Work with dirs
160163 def makedirs (self , path , remove_existing = False ):
164+ """
165+ Create a directory in the remote server.
166+ Args:
167+ - path (str): The path to the directory to be created.
168+ - remove_existing (bool): If True, the existing directory at the path will be removed.
169+ """
161170 if remove_existing :
162171 cmd = f"rm -rf { path } && mkdir -p { path } "
163172 else :
164173 cmd = f"mkdir -p { path } "
165- return self .exec_command (cmd )
174+ exit_status , result , error = self .exec_command (cmd , verbose = True )
175+ if exit_status != 0 :
176+ raise Exception (f"Couldn't create dir { path } because of error { error } " )
177+ return result
166178
167- def rmdirs (self , path , ignore_errors = True ):
179+ def rmdirs (self , path , verbose = False , ignore_errors = True ):
180+ """
181+ Remove a directory in the remote server.
182+ Args:
183+ - path (str): The path to the directory to be removed.
184+ - verbose (bool): If True, return exit status, result, and error.
185+ - ignore_errors (bool): If True, do not raise error if directory does not exist.
186+ """
168187 cmd = f"rm -rf { path } "
169- return self .exec_command (cmd )
188+ exit_status , result , error = self .exec_command (cmd , verbose = True )
189+ if verbose :
190+ return exit_status , result , error
191+ else :
192+ return result
170193
171194 def listdir (self , path ):
195+ """
196+ List all files and directories in a directory.
197+ Args:
198+ path (str): The path to the directory.
199+ """
172200 result = self .exec_command (f"ls { path } " )
173201 return result .splitlines ()
174202
175203 def path_exists (self , path ):
176- result = self .exec_command (f"test -e { path } ; echo $?" )
204+ result = self .exec_command (f"test -e { path } ; echo $?" , encoding = 'utf-8' )
177205 return int (result .strip ()) == 0
178206
179207 @property
@@ -188,7 +216,12 @@ def pathsep(self):
188216 return pathsep
189217
190218 def mkdtemp (self , prefix = None ):
191- temp_dir = self .exec_command (f"mkdtemp -d { prefix } " )
219+ """
220+ Creates a temporary directory in the remote server.
221+ Args:
222+ prefix (str): The prefix of the temporary directory name.
223+ """
224+ temp_dir = self .exec_command (f"mkdtemp -d { prefix } " , encoding = 'utf-8' )
192225 return temp_dir .strip ()
193226
194227 def mkstemp (self , prefix = None ):
@@ -200,18 +233,19 @@ def copytree(self, src, dst):
200233 return self .exec_command (f"cp -r { src } { dst } " )
201234
202235 # Work with files
203- def write (self , filename , data , truncate = False , binary = False , read_and_write = False ):
236+ def write (self , filename , data , truncate = False , binary = False , read_and_write = False , encoding = 'utf-8' ):
204237 """
205238 Write data to a file on a remote host
239+
206240 Args:
207- filename: The file path where the data will be written.
208- data : The data to be written to the file.
209- truncate: If True, the file will be truncated before writing ('w' or 'wb' option);
210- if False (default), data will be appended ('a' or 'ab' option).
211- binary: If True, the data will be written in binary mode ('wb' or 'ab' option);
212- if False (default), the data will be written in text mode ('w' or 'a' option).
213- read_and_write: If True, the file will be opened with read and write permissions ('r+' option);
214- if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option)
241+ - filename (str) : The file path where the data will be written.
242+ - data (bytes or str) : The data to be written to the file.
243+ - truncate (bool) : If True, the file will be truncated before writing ('w' or 'wb' option);
244+ if False (default), data will be appended ('a' or 'ab' option).
245+ - binary (bool) : If True, the data will be written in binary mode ('wb' or 'ab' option);
246+ if False (default), the data will be written in text mode ('w' or 'a' option).
247+ - read_and_write (bool) : If True, the file will be opened with read and write permissions ('r+' option);
248+ if False (default), only write permission will be used ('w', 'a', 'wb', or 'ab' option).
215249 """
216250 mode = "wb" if binary else "w"
217251 if not truncate :
@@ -220,15 +254,18 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
220254 mode = "r+b" if binary else "r+"
221255
222256 with tempfile .NamedTemporaryFile (mode = mode ) as tmp_file :
223- if isinstance (data , list ):
224- tmp_file .writelines (data )
225- else :
226- tmp_file .write (data )
257+ if isinstance (data , bytes ) and not binary :
258+ data = data .decode (encoding )
259+ elif isinstance (data , str ) and binary :
260+ data = data .encode (encoding )
261+
262+ tmp_file .write (data )
227263 tmp_file .flush ()
228264
229- sftp = self .ssh .open_sftp ()
230- sftp .put (tmp_file .name , filename )
231- sftp .close ()
265+ with self .ssh_connect () as ssh :
266+ sftp = ssh .open_sftp ()
267+ sftp .put (tmp_file .name , filename )
268+ sftp .close ()
232269
233270 def touch (self , filename ):
234271 """
@@ -281,8 +318,29 @@ def get_pid(self):
281318 return self .exec_command ("echo $$" )
282319
283320 # Database control
284- def db_connect (self , dbname , user , password = None , host = "localhost" , port = 5432 ):
285- local_port = self .ssh .forward_remote_port (host , port )
321+ def db_connect (self , dbname , user , password = None , host = "127.0.0.1" , hostname = "localhost" , port = 5432 ):
322+ """
323+ Connects to a PostgreSQL database on the remote system.
324+ Args:
325+ - dbname (str): The name of the database to connect to.
326+ - user (str): The username for the database connection.
327+ - password (str, optional): The password for the database connection. Defaults to None.
328+ - host (str, optional): The IP address of the remote system. Defaults to "127.0.0.1".
329+ - hostname (str, optional): The hostname of the remote system. Defaults to "localhost".
330+ - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432.
331+
332+ This function establishes a connection to a PostgreSQL database on the remote system using the specified
333+ parameters. It returns a connection object that can be used to interact with the database.
334+ """
335+ transport = self .ssh .get_transport ()
336+ local_port = 9090 # or any other available port
337+
338+ transport .open_channel (
339+ 'direct-tcpip' ,
340+ (hostname , port ),
341+ (host , local_port )
342+ )
343+
286344 conn = pglib .connect (
287345 host = host ,
288346 port = local_port ,
@@ -291,3 +349,4 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
291349 password = password ,
292350 )
293351 return conn
352+
0 commit comments