import subprocess
import re
import json
import sys
#import pprint

from flask import current_app
from time import sleep
from os.path import join
from subprocess import run

from capsulflask.db import get_model

from capsulflask.shared import VirtualizationInterface, VirtualMachine, validate_capsul_id, my_exec_info_message


class MockSpoke(VirtualizationInterface):
  """Fake spoke that performs no real virtualization.

  `create` remembers each capsul in the in-memory `self.capsuls` dict (keyed by
  capsul id), `get` reports every capsul as running, and the remaining commands
  only log what they were asked to do. Note that `destroy` does not remove the
  capsul from `self.capsuls`.
  """

  def __init__(self):
    # capsul id -> dict(email, id, network_name, public_ipv4), filled in by create()
    self.capsuls = {}

  def capacity_avaliable(self, additional_ram_bytes):
    """Always report that there is enough capacity."""
    return True

  def get(self, id, get_ssh_host_keys):
    """Return a running VirtualMachine for capsul `id`.

    The IPv4 address is the one recorded by create(), or "1.1.1.1" for ids that
    were never created. When `get_ssh_host_keys` is truthy, a fixed list of fake
    SSH host keys is attached as well.
    """
    validate_capsul_id(id)

    ipv4 = self.capsuls[id]['public_ipv4'] if id in self.capsuls else "1.1.1.1"
    vm_kwargs = {"ipv4": ipv4, "state": "running"}

    if get_ssh_host_keys:
      # Parsed on every call so each caller gets its own fresh list.
      vm_kwargs["ssh_host_keys"] = json.loads("""[
        {"key_type":"ED25519", "content":"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIN8cna0zeKSKl/r8whdn/KmDWhdzuWRVV0GaKIM+eshh", "sha256":"V4X2apAF6btGAfS45gmpldknoDX0ipJ5c6DLfZR2ttQ"},
        {"key_type":"RSA", "content":"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCvotgzgEP65JUQ8S8OoNKy1uEEPEAcFetSp7QpONe6hj4wPgyFNgVtdoWdNcU19dX3hpdse0G8OlaMUTnNVuRlbIZXuifXQ2jTtCFUA2mmJ5bF+XjGm3TXKMNGh9PN+wEPUeWd14vZL+QPUMev5LmA8cawPiU5+vVMLid93HRBj118aCJFQxLgrdP48VPfKHFRfCR6TIjg1ii3dH4acdJAvlmJ3GFB6ICT42EmBqskz2MPe0rIFxH8YohCBbAbrbWYcptHt4e48h4UdpZdYOhEdv89GrT8BF2C5cbQ5i9qVpI57bXKrj8hPZU5of48UHLSpXG8mbH0YDiOQOfKX/Mt", "sha256":"ghee6KzRnBJhND2kEUZSaouk7CD6o6z2aAc8GPkV+GQ"},
        {"key_type":"ECDSA", "content":"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBLLgOoATz9R4aS2kk7vWoxX+lshK63t9+5BIHdzZeFE1o+shlcf0Wji8cN/L1+m3bi0uSETZDOAWMP3rHLJj9Hk=", "sha256":"aCYG1aD8cv/TjzJL0bi9jdabMGksdkfa7R8dCGm1yYs"}
      ]""")

    return VirtualMachine(id, current_app.config["SPOKE_HOST_ID"], **vm_kwargs)

  def get_all_by_id(self) -> dict:
    """Return the model's non-deleted VMs for this spoke host, keyed by VM id.

    Each VM dict is modified in place: its 'state' is set to its
    'desired_state', i.e. the mock pretends every VM already is where it should be.
    """
    by_host_and_network = get_model().non_deleted_vms_by_host_and_network(current_app.config["SPOKE_HOST_ID"])

    all_vms = (
      vm
      for networks in by_host_and_network.values()
      for vm_list in networks.values()
      for vm in vm_list
    )

    result = {}
    for vm in all_vms:
      vm['state'] = vm['desired_state']
      result[vm['id']] = vm

    current_app.logger.info(f"MOCK get_all_by_id: {json.dumps(result)}")
    return result

  def create(self, email: str, id: str, template_image_file_name: str, vcpus: int, memory_mb: int, ssh_authorized_keys: list, network_name: str, public_ipv4: str):
    """Record the capsul in `self.capsuls` and pause for one second.

    Only email, id, network_name and public_ipv4 are remembered; the remaining
    arguments are accepted for interface compatibility and ignored.
    """
    validate_capsul_id(id)
    current_app.logger.info(f"mock create: {id} for {email}")
    self.capsuls[id] = {
      "email": email,
      "id": id,
      "network_name": network_name,
      "public_ipv4": public_ipv4,
    }
    # Presumably simulates the latency of a real create.
    sleep(1)

  def destroy(self, email: str, id: str):
    """Log the destroy request; nothing else happens."""
    current_app.logger.info(f"mock destroy: {id} for {email}")

  def vm_state_command(self, email: str, id: str, command: str):
    """Log the state command (e.g. start/stop); nothing else happens."""
    current_app.logger.info(f"mock {command}: {id} for {email}")

  def net_set_dhcp(self, email: str, host_id: str, network_name: str, macs: list, remove_ipv4: str, add_ipv4: str):
    """Log the requested DHCP reservation change; nothing else happens."""
    joined_macs = ','.join(macs)
    current_app.logger.info(f"mock net_set_dhcp: host_id={host_id} network_name={network_name} macs={joined_macs} remove_ipv4={remove_ipv4} add_ipv4={add_ipv4} for {email}")


class ShellScriptSpoke(VirtualizationInterface):
  """VirtualizationInterface implementation that drives the host through the
  helper scripts in `<flask app root>/shell_scripts/`.

  Every operation runs a script with `subprocess.run` (argument list, no shell)
  and inspects its exit code and stdout. Values that end up on a script's
  command line are checked against strict whitelists first.
  """

  def _script_path(self, script_name):
    """Return the path of `shell_scripts/<script_name>` under the Flask app root."""
    return join(current_app.root_path, f"shell_scripts/{script_name}")

  def _last_stdout_line(self, completedProcess):
    """Return the last stdout line of `completedProcess` as text, or "" if stdout is empty.

    The scripts report their verdict ("yes", "success", ...) on their final line.
    Returning "" for empty output lets callers take their normal failure path
    instead of crashing with an IndexError.
    """
    lines = completedProcess.stdout.splitlines()
    if not lines:
      return ""
    return lines[-1].decode("utf-8", errors="replace")

  def _run_json_script(self, script_name):
    """Run `shell_scripts/<script_name>` with no arguments and return its stdout parsed as JSON.

    Raises RuntimeError if the script exits non-zero, or a ValueError
    (json.JSONDecodeError / UnicodeDecodeError) if its output is not valid JSON.
    """
    completedProcess = run([self._script_path(script_name)], capture_output=True)
    self.validate_completed_process(completedProcess)
    return json.loads(completedProcess.stdout.decode("utf-8"))

  def _load_json_list_file(self, filename):
    """Parse `filename` as a JSON list; return [] (with a warning) if it is not one.

    The file must exist: a missing or unreadable file still raises OSError.
    """
    with open(filename, mode='r') as json_file:
      try:
        parsed = json.load(json_file)
      except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        current_app.logger.warning(f"failed to parse the JSON file '{filename}'")
        return []

    if not isinstance(parsed, list):
      current_app.logger.warning(f"expected a JSON list in '{filename}' but got {type(parsed).__name__}; ignoring it")
      return []

    return parsed

  def validate_completed_process(self, completedProcess, email=None):
    """Raise RuntimeError if `completedProcess` exited with a non-zero code.

    The message contains the command line, the optional `email` the operation
    was run for, and the captured stdout/stderr.
    """
    if completedProcess.returncode == 0:
      return

    emailPart = f"for {email}" if email is not None else ""
    raise RuntimeError(f"""{" ".join(completedProcess.args)} failed {emailPart} with exit code {completedProcess.returncode}
        stdout:
        {completedProcess.stdout}
        stderr:
        {completedProcess.stderr}
      """)

  def capacity_avaliable(self, additional_ram_bytes):
    """Return True if capacity-avaliable.sh answers "yes" for `additional_ram_bytes` more RAM.

    Returns False and logs an error when the script exits non-zero, or when its
    last line of output, including empty output, is anything other than "yes".
    """
    completedProcess = run([self._script_path('capacity-avaliable.sh'), str(additional_ram_bytes)], capture_output=True)

    if completedProcess.returncode != 0:
      current_app.logger.error(f"""
      capacity-avaliable.sh exited {completedProcess.returncode} with
        stdout:
        {completedProcess.stdout}
        stderr:
        {completedProcess.stderr}
      """)
      return False

    output = self._last_stdout_line(completedProcess)
    if output != "yes":
      current_app.logger.error(f"capacity-avaliable.sh exited 0 and returned {output} but did not return \"yes\" ")
      return False

    return True

  def get(self, id, get_ssh_host_keys):
    """Look up capsul `id` with get.sh and return a VirtualMachine, or None.

    The first line of get.sh's output is split on single spaces into
    `<exists> <state> <ipv4>`:

    - If there is no output, or the first field is not "true", None is returned.
    - Missing trailing fields, or an ipv4 that does not look like a dotted quad,
      give a VirtualMachine with correspondingly less information.

    When `get_ssh_host_keys` is truthy and an address is known, ssh-keyscan.sh
    is run to fetch the host keys. Any failure there is logged, and the VM is
    returned without keys.

    Raises RuntimeError if get.sh exits non-zero.
    """
    validate_capsul_id(id)
    completedProcess = run([self._script_path('get.sh'), id], capture_output=True)
    self.validate_completed_process(completedProcess)

    lines = completedProcess.stdout.splitlines()
    if not lines:
      current_app.logger.warning("shell_scripts/get.sh returned zero lines!")
      return None

    fields = lines[0].decode("utf-8").split(" ")
    if fields[0] != "true":
      current_app.logger.warning(f"shell_scripts/get.sh was called for {id} which libvirt says does not exist.")
      return None

    host_id = current_app.config["SPOKE_HOST_ID"]

    if len(fields) < 2:
      return VirtualMachine(id, host_id)

    state = fields[1]

    if len(fields) < 3:
      return VirtualMachine(id, host_id, state=state)

    ipaddr = fields[2]

    if not re.fullmatch(r"([0-9]{1,3}\.){3}[0-9]{1,3}", ipaddr):
      return VirtualMachine(id, host_id, state=state)

    if get_ssh_host_keys:
      try:
        completedProcess2 = run([self._script_path('ssh-keyscan.sh'), ipaddr], capture_output=True)
        self.validate_completed_process(completedProcess2)
        ssh_host_keys = json.loads(completedProcess2.stdout.decode("utf-8"))
        return VirtualMachine(id, host_id, state=state, ipv4=ipaddr, ssh_host_keys=ssh_host_keys)
      except Exception:
        # Best effort: fall through and return the VM without host keys.
        # (Exception rather than a bare except, so Ctrl-C / SystemExit still propagate.)
        current_app.logger.warning(f"""
          failed to ssh-keyscan {id} at {ipaddr}:
          {my_exec_info_message(sys.exc_info())}"""
        )

    return VirtualMachine(id, host_id, state=state, ipv4=ipaddr)

  def get_all_by_id(self) -> dict:
    """Return every VM known to libvirt or to libvirt's dnsmasq files, keyed by VM id.

    Each value is a dict with:
      id           -- the domain name
      state        -- as reported by virsh-list.sh, or "shut off" for a domain
                      that only appears in a network's .macs file
      macs         -- dict mapping each of the VM's MAC addresses to True
      network_name -- only present if the VM has a MAC listed in a network's .macs file
      public_ipv4  -- only present if a network's .status file has an entry for one of its MACs

    Raises:
      RuntimeError: virsh-list.sh or virsh-net-list.sh exits non-zero.
      Exception: the same MAC address is listed for two domains.
      OSError: a network's .macs or .status file cannot be opened.
    """
    vms_by_id = dict()
    for vm in self._run_json_script('virsh-list.sh'):
      vms_by_id[vm['id']] = dict(id=vm['id'], macs=dict(), state=vm['state'])

    # Shared across all networks so a MAC can only ever belong to one domain.
    vm_id_by_mac = dict()
    for network in self._run_json_script('virsh-net-list.sh'):
      file_prefix = f"{current_app.config['LIBVIRT_DNSMASQ_PATH']}/{network['virtual_bridge_name']}"

      for vm in self._load_json_list_file(f"{file_prefix}.macs"):
        domain = vm['domain']
        for mac in vm['macs']:
          if mac in vm_id_by_mac:
            raise Exception(f"the mac address '{mac}' is used by both '{vm_id_by_mac[mac]}' and '{domain}'")
          vm_id_by_mac[mac] = domain

          if domain not in vms_by_id:
            current_app.logger.warning(f"'{domain}' was in dnsmask but not in libvirt, defaulting to 'shut off' state")
            vms_by_id[domain] = dict(id=domain, macs=dict(), state="shut off")

          vms_by_id[domain]['network_name'] = network['network_name']
          vms_by_id[domain]['macs'][mac] = True

      for status in self._load_json_list_file(f"{file_prefix}.status"):
        mac = status['mac-address']
        if mac in vm_id_by_mac:
          vms_by_id[vm_id_by_mac[mac]]['public_ipv4'] = status['ip-address']
        else:
          current_app.logger.warning(f"get_all_by_id: {mac} not in vm_id_by_mac")

    return vms_by_id

  def create(self, email: str, id: str, template_image_file_name: str, vcpus: int, memory_mb: int, ssh_authorized_keys: list, network_name: str, public_ipv4: str):
    """Create capsul `id` with create.sh after validating the arguments.

    Validation uses re.fullmatch rather than `^...$`. In Python `$` also
    matches just before a trailing newline, so `^...$` would accept values
    such as "net\\n".

    Raises:
      ValueError: an argument fails validation (validate_capsul_id may also
        raise for a bad `id`), or create.sh's last output line is not "success".
      RuntimeError: create.sh exits non-zero.
    """
    validate_capsul_id(id)

    if not re.fullmatch(r"[a-zA-Z0-9/_.-]+", template_image_file_name):
      raise ValueError(f"template_image_file_name \"{template_image_file_name}\" must match \"^[a-zA-Z0-9/_.-]+$\"")

    for ssh_authorized_key in ssh_authorized_keys:
      if not re.fullmatch(r"(ssh|ecdsa)-[0-9A-Za-z+/_=@:. -]+", ssh_authorized_key):
        raise ValueError(f"ssh_authorized_key \"{ssh_authorized_key}\" must match \"^(ssh|ecdsa)-[0-9A-Za-z+/_=@:. -]+$\"")

    # Non-int values are rejected, not skipped: they would otherwise reach create.sh unchecked.
    if not isinstance(vcpus, int) or not (1 <= vcpus <= 8):
      raise ValueError(f"vcpus \"{vcpus}\" must match 1 <= vcpus <= 8")

    if not isinstance(memory_mb, int) or not (512 <= memory_mb <= 16384):
      raise ValueError(f"memory_mb \"{memory_mb}\" must match 512 <= memory_mb <= 16384")

    if not re.fullmatch(r"[a-zA-Z0-9_-]+", network_name):
      raise ValueError(f"network_name \"{network_name}\" must match \"^[a-zA-Z0-9_-]+\"")

    # NOTE(review): public_ipv4 reaches create.sh without validation; confirm this is intended.
    # if not re.match(r"^[0-9.]+$", public_ipv4):
    #   raise ValueError(f"public_ipv4 \"{public_ipv4}\" must match \"^[0-9.]+$\"")

    ssh_keys_string = "\n".join(ssh_authorized_keys)

    completedProcess = run([
      self._script_path('create.sh'),
      id,
      template_image_file_name,
      str(vcpus),
      str(memory_mb),
      ssh_keys_string,
      network_name,
      public_ipv4
    ], capture_output=True)

    self.validate_completed_process(completedProcess, email)
    status = self._last_stdout_line(completedProcess)

    if status != "success":
      vmSettings = f"""
      id={id}
      template_image_file_name={template_image_file_name}
      vcpus={str(vcpus)}
      memory={str(memory_mb)}
      ssh_authorized_keys={ssh_keys_string}
      network_name={network_name}
      public_ipv4={public_ipv4}
    """
      raise ValueError(f"""failed to create vm for {email} with:
        {vmSettings}
        stdout:
        {completedProcess.stdout}
        stderr:
        {completedProcess.stderr}
      """)

  def destroy(self, email: str, id: str):
    """Destroy capsul `id` with destroy.sh.

    Raises:
      RuntimeError: destroy.sh exits non-zero.
      ValueError: destroy.sh's last output line (or empty output) is not "success".
    """
    validate_capsul_id(id)
    completedProcess = run([self._script_path('destroy.sh'), id], capture_output=True)
    self.validate_completed_process(completedProcess, email)
    status = self._last_stdout_line(completedProcess)
    if status != "success":
      raise ValueError(f"""failed to destroy vm {id} for {email} on {current_app.config["SPOKE_HOST_ID"]}:
        stdout:
        {completedProcess.stdout}
        stderr:
        {completedProcess.stderr}
      """)

  def vm_state_command(self, email: str, id: str, command: str):
    """Run `shell_scripts/<command>.sh <id>` and log its output.

    `command` must be one of "stop", "force-stop", "start" or "restart". It is
    whitelisted because it becomes part of the script path.

    Raises ValueError for any other command, and RuntimeError if the script
    exits non-zero.
    """
    validate_capsul_id(id)
    if command not in ("stop", "force-stop", "start", "restart"):
      raise ValueError(f"command ({command}) must be one of stop, force-stop, start, or restart")

    completedProcess = run([self._script_path(f"{command}.sh"), id], capture_output=True)
    self.validate_completed_process(completedProcess, email)
    returned_string = completedProcess.stdout.decode("utf-8")
    current_app.logger.info(f"{command} vm {id} for {email} returned: {returned_string}")

  def net_set_dhcp(self, email: str, host_id: str, network_name: str, macs: list, remove_ipv4: str, add_ipv4: str):
    """Move the DHCP reservations of `macs` on `network_name` from `remove_ipv4` to `add_ipv4`.

    For each MAC, ip-dhcp-host.sh is run with "delete" when `remove_ipv4` is
    non-empty, then with "add" when `add_ipv4` is non-empty. `host_id` is not
    used by this implementation.

    Every argument is validated before any script runs, so an invalid
    `add_ipv4` cannot leave reservations deleted but not re-added.

    Raises ValueError on invalid input, and RuntimeError if a script exits non-zero.
    """
    if not re.fullmatch(r"[a-zA-Z0-9_-]+", network_name):
      raise ValueError(f"network_name \"{network_name}\" must match \"^[a-zA-Z0-9_-]+\"")

    if not isinstance(macs, list):
      raise ValueError("macs must be a list")

    for mac in macs:
      if not re.fullmatch(r"[0-9a-f:]+", mac):
        raise ValueError(f"mac \"{mac}\" must match \"^[0-9a-f:]+$\"")

    if remove_ipv4 and not re.fullmatch(r"[0-9.]+", remove_ipv4):
      raise ValueError(f"remove_ipv4 \"{remove_ipv4}\" must match \"^[0-9.]+$\"")

    if add_ipv4 and not re.fullmatch(r"[0-9.]+", add_ipv4):
      raise ValueError(f"add_ipv4 \"{add_ipv4}\" must match \"^[0-9.]+$\"")

    script = self._script_path("ip-dhcp-host.sh")

    if remove_ipv4:
      for mac in macs:
        completedProcess = run([script, "delete", network_name, mac, remove_ipv4], capture_output=True)
        self.validate_completed_process(completedProcess, email)

    if add_ipv4:
      for mac in macs:
        completedProcess = run([script, "add", network_name, mac, add_ipv4], capture_output=True)
        self.validate_completed_process(completedProcess, email)