multiball/multiball/fabtools.py

68 lines
2.3 KiB
Python

from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Thread, Lock
from typing import Iterable, Optional, Union

from tqdm import tqdm
from fabric2 import ThreadingGroup, SerialGroup, Config, Connection
def thread_run(connection, command, result_lock, result_queue):
    """Run *command* on *connection* and record the outcome in *result_queue*.

    Intended as a worker-thread target.

    Appends ``(connection, outcome)`` to *result_queue* (a list shared between
    workers) while holding *result_lock*. ``outcome`` is whatever
    ``connection.run(command, warn=True, hide=True)`` returns.

    If that call raises, the exception instance is recorded as the outcome
    instead. ``warn=True`` only turns non-zero exit statuses into a result, not
    e.g. connection or authentication failures. Without this, the host would
    silently disappear from the collected results.
    """
    try:
        outcome = connection.run(command, warn=True, hide=True)
    except Exception as exc:  # worker-thread boundary: report the failure, don't lose the host
        outcome = exc
    with result_lock:
        result_queue.append((connection, outcome))
class HostSet:
    """A set of hosts that can be targeted with a series of commands.

    Each command runs on every host concurrently. The collected output is then
    printed grouped by identical result, so hosts that agree are listed together
    under a single copy of their output.
    """

    def __init__(self, hostlist: Iterable[str], ssh_config_path: Optional[Union[str, Path]] = None):
        """Create one Fabric ``Connection`` per host.

        :param hostlist: each item is passed as the ``host`` argument of
            ``Connection``.
        :param ssh_config_path: ssh config file to use, as a ``Path`` or ``str``.
            ``~`` is expanded. Defaults to ``~/.ssh/config``.
        """
        path = Path("~/.ssh/config") if ssh_config_path is None else Path(ssh_config_path)
        # Supplied as Fabric's ``ssh_config_path`` config override; all
        # connections share this one Config object.
        config = Config({"ssh_config_path": str(path.expanduser())})
        self.connections = [Connection(host, config=config) for host in hostlist]

    @staticmethod
    def _render(result):
        """Return the text used to group and display one host's result."""
        if isinstance(result, BaseException):
            # Include the type so exceptions with empty messages stay distinguishable.
            return '{}: {}'.format(type(result).__name__, result)
        return str(result)

    def run(self, command: str):
        """Run *command* on every host in parallel and print the grouped output.

        A progress bar advances as each host finishes. Afterwards, hosts whose
        results render to the same text are printed together: a header line
        ``-----> [host1 host2 ...]`` followed by that text.

        Hosts and groups appear in the order the hosts were given to the
        constructor. A host whose command raised an exception is reported with
        the exception's type and message rather than being dropped.
        """
        results = []
        lock = Lock()
        with tqdm(total=len(self.connections), unit="hosts") as progress:
            # One worker per host, as before. max(..., 1) is needed because
            # ThreadPoolExecutor rejects max_workers=0 for an empty host set.
            with ThreadPoolExecutor(max_workers=max(len(self.connections), 1)) as pool:
                futures = {
                    pool.submit(thread_run, connection, command, lock, results): connection
                    for connection in self.connections
                }
                # as_completed blocks until the next worker finishes, so the
                # main thread sleeps instead of spinning on is_alive().
                for future in as_completed(futures):
                    error = future.exception()
                    if error is not None:
                        # The worker raised before recording anything; record it here.
                        with lock:
                            results.append((futures[future], error))
                    progress.update()

        # Workers append in completion order; restore constructor order so the
        # output is deterministic from run to run.
        position = {id(connection): index for index, connection in enumerate(self.connections)}
        results.sort(key=lambda item: position[id(item[0])])

        # Group hosts by the text of their result.
        gathered = {}
        for connection, result in results:
            gathered.setdefault(self._render(result), []).append(connection)

        for output, connections in gathered.items():
            print('-----> [{}]'.format(' '.join(connection.original_host for connection in connections)))
            print(output)