2020error_markers = [b'error' , b'Permission denied' ]
2121
2222
23+ class PsUtilProcessProxy :
24+ def __init__ (self , ssh , pid ):
25+ self .ssh = ssh
26+ self .pid = pid
27+
28+ def kill (self ):
29+ command = f"kill { self .pid } "
30+ self .ssh .exec_command (command )
31+
32+ def cmdline (self ):
33+ command = f"ps -p { self .pid } -o cmd --no-headers"
34+ stdin , stdout , stderr = self .ssh .exec_command (command )
35+ cmdline = stdout .read ().decode ('utf-8' ).strip ()
36+ return cmdline .split ()
37+
38+
2339class RemoteOperations (OsOperations ):
2440 def __init__ (self , host = "127.0.0.1" , hostname = 'localhost' , port = None , ssh_key = None , username = None ):
2541 super ().__init__ (username )
@@ -71,7 +87,7 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa
7187 self .ssh = self .ssh_connect ()
7288
7389 if isinstance (cmd , list ):
74- cmd = " " .join (cmd )
90+ cmd = ' ' .join (item . decode ( 'utf-8' ) if isinstance ( item , bytes ) else item for item in cmd )
7591 if input :
7692 stdin , stdout , stderr = self .ssh .exec_command (cmd )
7793 stdin .write (input )
@@ -140,17 +156,6 @@ def is_executable(self, file):
140156 is_exec = self .exec_command (f"test -x { file } && echo OK" )
141157 return is_exec == b"OK\n "
142158
143- def add_to_path (self , new_path ):
144- pathsep = self .pathsep
145- # Check if the directory is already in PATH
146- path = self .environ ("PATH" )
147- if new_path not in path .split (pathsep ):
148- if self .remote :
149- self .exec_command (f"export PATH={ new_path } { pathsep } { path } " )
150- else :
151- os .environ ["PATH" ] = f"{ new_path } { pathsep } { path } "
152- return pathsep
153-
154159 def set_env (self , var_name : str , var_val : str ):
155160 """
156161 Set the value of an environment variable.
@@ -243,9 +248,17 @@ def mkdtemp(self, prefix=None):
243248 raise ExecUtilException ("Could not create temporary directory." )
244249
245250 def mkstemp (self , prefix = None ):
246- cmd = f"mktemp { prefix } XXXXXX"
247- filename = self .exec_command (cmd ).strip ()
248- return filename
251+ if prefix :
252+ temp_dir = self .exec_command (f"mktemp { prefix } XXXXX" , encoding = 'utf-8' )
253+ else :
254+ temp_dir = self .exec_command ("mktemp" , encoding = 'utf-8' )
255+
256+ if temp_dir :
257+ if not os .path .isabs (temp_dir ):
258+ temp_dir = os .path .join ('/home' , self .username , temp_dir .strip ())
259+ return temp_dir
260+ else :
261+ raise ExecUtilException ("Could not create temporary directory." )
249262
250263 def copytree (self , src , dst ):
251264 if not os .path .isabs (dst ):
@@ -291,7 +304,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
291304 data = data .encode (encoding )
292305 if isinstance (data , list ):
293306 # ensure each line ends with a newline
294- data = [s if s . endswith ( ' \n ' ) else s + '\n ' for s in data ]
307+ data = [( s if isinstance ( s , str ) else s . decode ( 'utf-8' )). rstrip ( ' \n ' ) + '\n ' for s in data ]
295308 tmp_file .writelines (data )
296309 else :
297310 tmp_file .write (data )
@@ -351,8 +364,8 @@ def isfile(self, remote_file):
351364
352365 def isdir (self , dirname ):
353366 cmd = f"if [ -d { dirname } ]; then echo True; else echo False; fi"
354- response = self .exec_command (cmd , encoding = 'utf-8' )
355- return response .strip () == "True"
367+ response = self .exec_command (cmd )
368+ return response .strip () == b "True"
356369
357370 def remove_file (self , filename ):
358371 cmd = f"rm { filename } "
@@ -366,16 +379,16 @@ def kill(self, pid, signal):
366379
367380 def get_pid (self ):
368381 # Get current process id
369- return self .exec_command ("echo $$" )
382+ return int ( self .exec_command ("echo $$" , encoding = 'utf-8' ) )
370383
371384 def get_remote_children (self , pid ):
372385 command = f"pgrep -P { pid } "
373386 stdin , stdout , stderr = self .ssh .exec_command (command )
374387 children = stdout .readlines ()
375- return [int (child_pid .strip ()) for child_pid in children ]
388+ return [PsUtilProcessProxy ( self . ssh , int (child_pid .strip () )) for child_pid in children ]
376389
377390 # Database control
378- def db_connect (self , dbname , user , password = None , host = "127.0.0.1" , port = 5432 ):
391+ def db_connect (self , dbname , user , password = None , host = "127.0.0.1" , port = 5432 , ssh_key = None ):
379392 """
380393 Connects to a PostgreSQL database on the remote system.
381394 Args:
@@ -389,19 +402,26 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432):
389402 This function establishes a connection to a PostgreSQL database on the remote system using the specified
390403 parameters. It returns a connection object that can be used to interact with the database.
391404 """
392- with sshtunnel .open_tunnel (
393- (host , 22 ), # Remote server IP and SSH port
394- ssh_username = self .username ,
395- ssh_pkey = self .ssh_key ,
396- remote_bind_address = (host , port ), # PostgreSQL server IP and PostgreSQL port
397- local_bind_address = ('localhost' , port ), # Local machine IP and available port
398- ):
405+ tunnel = sshtunnel .open_tunnel (
406+ (host , 22 ), # Remote server IP and SSH port
407+ ssh_username = user or self .username ,
408+ ssh_pkey = ssh_key or self .ssh_key ,
409+ remote_bind_address = (host , port ), # PostgreSQL server IP and PostgreSQL port
410+ local_bind_address = ('localhost' , port ) # Local machine IP and available port
411+ )
412+
413+ tunnel .start ()
414+
415+ try :
399416 conn = pglib .connect (
400- host = host ,
401- port = port ,
417+ host = host , # change to 'localhost' because we're connecting through a local ssh tunnel
418+ port = tunnel . local_bind_port , # use the local bind port set up by the tunnel
402419 dbname = dbname ,
403- user = user ,
420+ user = user or self . username ,
404421 password = password
405422 )
406423
407- return conn
424+ return conn
425+ except Exception as e :
426+ tunnel .stop ()
427+ raise e
0 commit comments