74 lines
2.5 KiB
Python
74 lines
2.5 KiB
Python
from pathlib import Path
|
|
from threading import Thread, Lock
|
|
|
|
from tqdm import tqdm
|
|
|
|
from fabric2 import ThreadingGroup, SerialGroup, Config, Connection
|
|
|
|
import paramiko
|
|
|
|
|
|
def thread_run(connection, command, result_lock, result_queue):
    """Run *command* on *connection* and append ``(connection, result)`` to *result_queue*.

    Intended as a ``threading.Thread`` target. Exactly one entry is appended
    per call, under *result_lock*, whether or not the command could be run.

    :param connection: a Fabric ``Connection`` (anything with a ``run()`` method).
    :param command: shell command string to execute remotely.
    :param result_lock: lock guarding *result_queue*.
    :param result_queue: shared list receiving ``(connection, result)`` tuples;
        *result* is whatever ``connection.run`` returned, or an error string.
    """
    try:
        # warn=True: a non-zero exit status is reported in the returned Result
        # instead of raising; hide=True: don't echo remote output locally.
        res = connection.run(command, warn=True, hide=True)
    except OSError as inst:
        # Covers paramiko's NoValidConnectionsError (a socket.error/OSError
        # subclass) plus DNS failures (socket.gaierror), timeouts and refused
        # connections, which previously escaped and killed the thread.
        res = f"Could not connect to host: {inst}"
    except paramiko.ssh_exception.SSHException as inst:
        # Handshake/authentication failures (AuthenticationException is an
        # SSHException subclass).
        res = f"Could not connect to host: {inst}"
    except Exception as inst:
        # Thread boundary: record the failure so this host still appears in
        # the report instead of vanishing with only a traceback on stderr.
        res = f"Error running command: {inst!r}"
    with result_lock:
        result_queue.append((connection, res))
|
|
|
|
class HostSet:
    """A set of SSH hosts that can all be sent the same command.

    Output is collected per host, and hosts that produced identical output
    are grouped together when printed and returned.
    """

    def __init__(self, hostlist: list, ssh_config_path: "Path | str | None" = None):
        """Create one Fabric ``Connection`` per host.

        :param hostlist: host strings passed to ``Connection``; any iterable works.
        :param ssh_config_path: OpenSSH config file used for every connection,
            as a ``Path`` or ``str``; ``~`` is expanded. Defaults to
            ``~/.ssh/config``.
        """
        if ssh_config_path is None:
            ssh_config_path = "~/.ssh/config"
        # Wrapping in Path() lets callers pass a plain string as well.
        ssh_config_path = Path(ssh_config_path).expanduser()
        config = Config({"ssh_config_path": str(ssh_config_path)})
        self.connections = [Connection(host, config=config) for host in hostlist]

    def run(self, command: str):
        """Run *command* on every host concurrently and print grouped output.

        One thread per host executes ``thread_run``; a tqdm bar counts finished
        hosts. Hosts whose results stringify identically are printed under a
        single ``-----> [host ...]`` header, in hostlist order.

        :param command: shell command passed to each ``Connection.run``.
        :returns: dict mapping each distinct result string to the list of
            connections that produced it. (Previously ``None``; callers that
            ignored the return value are unaffected.)
        """
        results = []
        results_lock = Lock()
        threads = []
        for connection in self.connections:
            t = Thread(target=thread_run, args=[connection, command, results_lock, results])
            t.start()
            threads.append(t)

        self._wait_with_progress(threads)

        # A worker that hit an exception thread_run doesn't handle exits without
        # appending anything; report those hosts instead of silently dropping them.
        reported = {id(conn) for conn, _ in results}
        results.extend(
            (conn, "No result: worker thread exited without reporting")
            for conn in self.connections
            if id(conn) not in reported
        )

        gathered = self._group_results(results, self.connections)
        for output, connections in gathered.items():
            print('-----> [{}]'.format(' '.join(conn.original_host for conn in connections)))
            print(output)
        return gathered

    @staticmethod
    def _wait_with_progress(threads, poll_interval=0.1):
        """Block until every thread in *threads* finishes, ticking a tqdm bar per finished thread."""
        pending = list(threads)
        # Context manager guarantees the bar is closed even on KeyboardInterrupt.
        with tqdm(total=len(pending), unit="hosts") as prog:
            while pending:
                # Block briefly on one thread rather than spinning the CPU.
                pending[0].join(timeout=poll_interval)
                still_running = [t for t in pending if t.is_alive()]
                prog.update(len(pending) - len(still_running))
                pending = still_running

    @staticmethod
    def _group_results(results, order):
        """Group ``(connection, result)`` pairs by ``str(result)``.

        Pairs are first sorted by each connection's position in *order*, so the
        groups and the hosts inside them come out in a deterministic order
        rather than thread-completion order. Connections not in *order* sort last.
        """
        position = {id(conn): i for i, conn in enumerate(order)}
        ordered = sorted(results, key=lambda pair: position.get(id(pair[0]), len(position)))
        gathered = {}
        for conn, result in ordered:
            gathered.setdefault(str(result), []).append(conn)
        return gathered
|