1- import logging
21import os
2+ import socket
33import subprocess
44import tempfile
55import platform
@@ -45,47 +45,44 @@ def __init__(self, conn_params: ConnectionParams):
4545 self .conn_params = conn_params
4646 self .host = conn_params .host
4747 self .ssh_key = conn_params .ssh_key
48+ self .port = conn_params .port
49+ self .ssh_args = []
4850 if self .ssh_key :
49- self .ssh_cmd = ["-i" , self .ssh_key ]
50- else :
51- self .ssh_cmd = []
51+ self .ssh_args + = ["-i" , self .ssh_key ]
52+ if self . port :
53+ self .ssh_args + = ["-p" , self . port ]
5254 self .remote = True
5355 self .username = conn_params .username
5456 self .ssh_dest = f"{ self .username } @{ self .host } " if self .username else self .host
5557 self .add_known_host (self .host )
5658 self .tunnel_process = None
59+ self .tunnel_port = None
5760
5861 def __enter__ (self ):
5962 return self
6063
6164 def __exit__ (self , exc_type , exc_val , exc_tb ):
6265 self .close_ssh_tunnel ()
6366
64- def establish_ssh_tunnel (self , local_port , remote_port ):
65- """
66- Establish an SSH tunnel from a local port to a remote PostgreSQL port.
67- """
68- ssh_cmd = ['-N' , '-L' , f"{ local_port } :localhost:{ remote_port } " ]
69- self .tunnel_process = self .exec_command (ssh_cmd , get_process = True , timeout = 300 )
67+ @staticmethod
68+ def is_port_open (host , port ):
69+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as sock :
70+ sock .settimeout (1 ) # Таймаут для попытки соединения
71+ try :
72+ sock .connect ((host , port ))
73+ return True
74+ except socket .error :
75+ return False
7076
7177 def close_ssh_tunnel (self ):
72- if hasattr ( self , ' tunnel_process' ) :
78+ if self . tunnel_process :
7379 self .tunnel_process .terminate ()
7480 self .tunnel_process .wait ()
81+ print ("SSH tunnel closed." )
7582 del self .tunnel_process
7683 else :
7784 print ("No active tunnel to close." )
7885
79- def add_known_host (self , host ):
80- known_hosts_path = os .path .expanduser ("~/.ssh/known_hosts" )
81- cmd = 'ssh-keyscan -H %s >> %s' % (host , known_hosts_path )
82-
83- try :
84- subprocess .check_call (cmd , shell = True )
85- logging .info ("Successfully added %s to known_hosts." % host )
86- except subprocess .CalledProcessError as e :
87- raise Exception ("Failed to add %s to known_hosts. Error: %s" % (host , str (e )))
88-
8986 def exec_command (self , cmd , wait_exit = False , verbose = False , expect_error = False ,
9087 encoding = None , shell = True , text = False , input = None , stdin = None , stdout = None ,
9188 stderr = None , get_process = None , timeout = None ):
@@ -96,9 +93,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
9693 """
9794 ssh_cmd = []
9895 if isinstance (cmd , str ):
99- ssh_cmd = ['ssh' , self . ssh_dest ] + self .ssh_cmd + [cmd ]
96+ ssh_cmd = ['ssh' ] + self .ssh_args + [self . ssh_dest , cmd ]
10097 elif isinstance (cmd , list ):
101- ssh_cmd = ['ssh' , self .ssh_dest ] + self .ssh_cmd + cmd
98+ ssh_cmd = ['ssh' ] + self .ssh_args + [ self .ssh_dest ] + cmd
10299 process = subprocess .Popen (ssh_cmd , stdin = subprocess .PIPE , stdout = subprocess .PIPE , stderr = subprocess .PIPE )
103100 if get_process :
104101 return process
@@ -243,9 +240,9 @@ def mkdtemp(self, prefix=None):
243240 - prefix (str): The prefix of the temporary directory name.
244241 """
245242 if prefix :
246- command = ["ssh" ] + self .ssh_cmd + [self .ssh_dest , f"mktemp -d { prefix } XXXXX" ]
243+ command = ["ssh" ] + self .ssh_args + [self .ssh_dest , f"mktemp -d { prefix } XXXXX" ]
247244 else :
248- command = ["ssh" ] + self .ssh_cmd + [self .ssh_dest , "mktemp -d" ]
245+ command = ["ssh" ] + self .ssh_args + [self .ssh_dest , "mktemp -d" ]
249246
250247 result = subprocess .run (command , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True )
251248
@@ -288,8 +285,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
288285 mode = "r+b" if binary else "r+"
289286
290287 with tempfile .NamedTemporaryFile (mode = mode , delete = False ) as tmp_file :
288+ # Because in scp we set up port using -P option
289+ scp_args = ['-P' if x == '-p' else x for x in self .ssh_args ]
290+
291291 if not truncate :
292- scp_cmd = ['scp' ] + self . ssh_cmd + [f"{ self .ssh_dest } :{ filename } " , tmp_file .name ]
292+ scp_cmd = ['scp' ] + scp_args + [f"{ self .ssh_dest } :{ filename } " , tmp_file .name ]
293293 subprocess .run (scp_cmd , check = False ) # The file might not exist yet
294294 tmp_file .seek (0 , os .SEEK_END )
295295
@@ -305,11 +305,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
305305 tmp_file .write (data )
306306
307307 tmp_file .flush ()
308- scp_cmd = ['scp' ] + self . ssh_cmd + [tmp_file .name , f"{ self .ssh_dest } :{ filename } " ]
308+ scp_cmd = ['scp' ] + scp_args + [tmp_file .name , f"{ self .ssh_dest } :{ filename } " ]
309309 subprocess .run (scp_cmd , check = True )
310310
311311 remote_directory = os .path .dirname (filename )
312- mkdir_cmd = ['ssh' ] + self .ssh_cmd + [self .ssh_dest , f"mkdir -p { remote_directory } " ]
312+ mkdir_cmd = ['ssh' ] + self .ssh_args + [self .ssh_dest , f"mkdir -p { remote_directory } " ]
313313 subprocess .run (mkdir_cmd , check = True )
314314
315315 os .remove (tmp_file .name )
@@ -374,7 +374,7 @@ def get_pid(self):
374374 return int (self .exec_command ("echo $$" , encoding = get_default_encoding ()))
375375
376376 def get_process_children (self , pid ):
377- command = ["ssh" ] + self .ssh_cmd + [self .ssh_dest , f"pgrep -P { pid } " ]
377+ command = ["ssh" ] + self .ssh_args + [self .ssh_dest , f"pgrep -P { pid } " ]
378378
379379 result = subprocess .run (command , stdout = subprocess .PIPE , stderr = subprocess .PIPE , text = True )
380380
@@ -386,18 +386,11 @@ def get_process_children(self, pid):
386386
387387 # Database control
388388 def db_connect (self , dbname , user , password = None , host = "localhost" , port = 5432 ):
389- """
390- Established SSH tunnel and Connects to a PostgreSQL
391- """
392- self .establish_ssh_tunnel (local_port = port , remote_port = 5432 )
393- try :
394- conn = pglib .connect (
395- host = host ,
396- port = port ,
397- database = dbname ,
398- user = user ,
399- password = password ,
400- )
401- return conn
402- except Exception as e :
403- raise Exception (f"Could not connect to the database. Error: { e } " )
389+ conn = pglib .connect (
390+ host = host ,
391+ port = port ,
392+ database = dbname ,
393+ user = user ,
394+ password = password ,
395+ )
396+ return conn
0 commit comments