import re
import sys
import json
from datetime import datetime, timedelta
from flask import Blueprint
from flask import flash
from flask import current_app
from flask import g
from flask import request
from flask import session
from flask import render_template
from flask import redirect
from flask import url_for
from werkzeug.exceptions import abort
from nanoid import generate

from capsulflask.metrics import durations as metric_durations
from capsulflask.auth import account_required
from capsulflask.db import get_model
from capsulflask.shared import my_exec_info_message
from capsulflask.payment import poll_btcpay_session
from capsulflask import cli

# the "console" blueprint: every route below is served under the /console prefix
bp = Blueprint(name="console", import_name=__name__, url_prefix="/console")

def makeCapsulId():
  """Return a new capsul id: "capsul-" followed by 10 nanoid characters from [0-9a-z]."""
  suffix = generate(size=10, alphabet="1234567890qwertyuiopasdfghjklzxcvbnm")
  return "capsul-" + suffix

def double_check_capsul_address(id, ipv4, get_ssh_host_keys):
  """Re-read a capsul's address (and optionally its SSH host keys) from the hub.

  Asks ``HUB_MODEL`` for the current state of capsul ``id``. If the hub reports
  a different IPv4 than ``ipv4`` (the value we currently have stored), the
  stored value is updated in the database. When ``get_ssh_host_keys`` is true,
  the host keys returned by the hub are persisted as well.

  Returns the hub's result object, or None if anything along the way raised.
  Errors are logged rather than propagated; callers (``index``/``detail``)
  fall back to the values they already have when None is returned.
  """
  try:
    result = current_app.config["HUB_MODEL"].get(id, get_ssh_host_keys)
    if result.ipv4 != ipv4:
      get_model().update_vm_ip(email=session["account"], id=id, ipv4=result.ipv4)
    if get_ssh_host_keys:
      get_model().update_vm_ssh_host_keys(email=session["account"], id=id, ssh_host_keys=result.ssh_host_keys)
  except Exception:
    # best effort: log and report "unknown" instead of failing the page.
    # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
    current_app.logger.error(f"""
      the virtualization model threw an error in double_check_capsul_address of {id}:
      {my_exec_info_message(sys.exc_info())}"""
    )
    return None

  return result

@bp.route("/")
@account_required
def index():
  """Render the list of the current account's non-deleted capsuls.

  The optional ``created`` query parameter (appended by ``create`` after a
  successful POST) is passed through to the template, but only if it looks
  like a capsul id; anything else is replaced with a placeholder.
  """
  vms = [vm for vm in get_vms() if not vm['deleted']]
  created = request.args.get('created')

  # this is here to prevent xss. re.fullmatch is used because "$" in
  # re.match would still accept a trailing newline after a valid id.
  if created and not re.fullmatch(r"(cvm|capsul)-[a-z0-9]{10}", created):
    created = '___________'

  # for now we are going to check the IP according to the virt model
  # on every request. this could be done by a background job and cached later on...
  for vm in vms:
    result = double_check_capsul_address(vm["id"], vm["ipv4"], False)
    if result is not None:
      vm["ipv4"] = result.ipv4

  vms = [
    dict(
      id=vm['id'],
      size=vm['size'],
      ipv4=(vm['ipv4'] if vm['ipv4'] else "..booting.."),
      ipv4_status=("ok" if vm['ipv4'] else "waiting-pulse"),
      os=vm['os'],
      created=vm['created'].strftime("%b %d %Y"),
    )
    for vm in vms
  ]

  return render_template("capsuls.html", vms=vms, has_vms=len(vms) > 0, created=created)

@bp.route("/<string:id>", methods=("GET", "POST"))
@account_required
def detail(id):
  """Show a single capsul, or drive the two-step delete flow for it.

  Non-POST requests render the detail page after refreshing the IPv4 (and,
  if none are stored yet, the SSH host keys) from the hub. The ``duration``
  query parameter (default "5m") is passed to the template together with the
  available metric durations.

  POST is the delete form (CSRF-checked). Without a truthy ``are_you_sure``
  field it renders the confirmation page; with it, the capsul is destroyed on
  the hub and deleted in the database.
  """
  duration = request.args.get('duration') or "5m"

  vm = get_model().get_vm_detail(email=session["account"], id=id)
  if vm is None:
    return abort(404, f"{id} doesn't exist.")
  if vm['deleted']:
    return render_template("capsul-detail.html", vm=vm, delete=True, deleted=True)

  if request.method != "POST":
    needs_ssh_host_keys = "ssh_host_keys" not in vm or len(vm["ssh_host_keys"]) == 0

    hub_vm = double_check_capsul_address(vm["id"], vm["ipv4"], needs_ssh_host_keys)
    if hub_vm is not None:
      vm["ipv4"] = hub_vm.ipv4
      if needs_ssh_host_keys:
        vm["ssh_host_keys"] = hub_vm.ssh_host_keys

    vm["created"] = vm['created'].strftime("%b %d %Y %H:%M")
    authorized_keys = vm["ssh_authorized_keys"]
    vm["ssh_authorized_keys"] = ", ".join(authorized_keys) if len(authorized_keys) > 0 else "<missing>"

    return render_template(
      "capsul-detail.html",
      csrf_token=session["csrf-token"],
      vm=vm,
      delete=False,
      durations=[name.strip("_") for name in metric_durations.keys()],
      duration=duration,
    )

  if "csrf-token" not in request.form or request.form['csrf-token'] != session['csrf-token']:
    return abort(418, f"u want tea")

  if not request.form.get('are_you_sure'):
    # first step: ask the user to confirm the deletion
    return render_template(
      "capsul-detail.html",
      csrf_token=session["csrf-token"],
      vm=vm,
      delete=True,
      deleted=False,
    )

  current_app.logger.info(f"deleting {vm['id']} per user request ({session['account']})")
  current_app.config["HUB_MODEL"].destroy(email=session['account'], id=id)
  get_model().delete_vm(email=session['account'], id=id)

  return render_template("capsul-detail.html", vm=vm, delete=True, deleted=True)


def _can_afford(vm_size, balance):
  """Return True if a capsul of ``vm_size`` is affordable with ``balance`` dollars.

  If a user deposits $7.50 and then creates an f1-s vm which costs 7.50 a month,
  then they have to delete the vm and re-create it, they will not be able to, they
  will have to pay again. So for UX it makes a lot of sense to give a small margin
  of 25 cents for usability sake.
  """
  return vm_size["dollars_per_month"] <= balance + 0.25

@bp.route("/create", methods=("GET", "POST"))
@account_required
def create():
  """Render the "create capsul" form and, on a valid POST, create the capsul.

  A valid POST records the capsul in the database, asks the hub to create it,
  and redirects to the console index with ``?created=<id>``. Otherwise any
  validation errors are flashed and the form is re-rendered, offering only the
  sizes the account can afford.
  """
  vm_sizes = get_model().vm_sizes_dict()
  operating_systems = get_model().operating_systems_dict()
  public_keys_for_account = get_model().list_ssh_public_keys_for_account(session["account"])
  account_balance = get_account_balance(get_vms(), get_payments(), datetime.utcnow())
  # default capacity probe (512 MiB); refined below once a valid size is posted
  capacity_avaliable = current_app.config["HUB_MODEL"].capacity_avaliable(512*1024*1024)
  errors = list()

  if request.method == "POST":
    if "csrf-token" not in request.form or request.form['csrf-token'] != session['csrf-token']:
      return abort(418, f"u want tea")

    size = request.form["size"]
    os_name = request.form["os"]
    if not size:
      errors.append("Size is required")
    elif size not in vm_sizes:
      errors.append(f"Invalid size {size}")
    elif not _can_afford(vm_sizes[size], account_balance):
      # the form only offers affordable sizes; enforce the same rule server-side
      errors.append(f"Your account balance is too low to create a capsul of size {size}")

    if not os_name:
      errors.append("OS is required")
    elif os_name not in operating_systems:
      errors.append(f"Invalid os {os_name}")

    try:
      posted_keys_count = int(request.form["ssh_public_key_count"])
    except ValueError:
      posted_keys_count = None  # a tampered / non-numeric count: reported below
    posted_keys = list()

    if posted_keys_count is None or posted_keys_count > 1000:
      errors.append("something went wrong with ssh keys")
    else:
      keys_by_name = {key['name']: key for key in public_keys_for_account}
      for i in range(posted_keys_count):
        field_name = f"ssh_key_{i}"
        if field_name not in request.form:
          # unchecked checkboxes are not submitted at all
          continue
        posted_name = request.form[field_name]
        key = keys_by_name.get(posted_name)
        if key:
          posted_keys.append(key)
        else:
          errors.append(f"SSH Key \"{posted_name}\" doesn't exist")

    if len(posted_keys) == 0:
      errors.append("At least one SSH Public Key is required")

    # only ask about capacity for a known size; an empty/unknown size was already
    # reported above and would otherwise raise KeyError here
    if size in vm_sizes:
      capacity_avaliable = current_app.config["HUB_MODEL"].capacity_avaliable(vm_sizes[size]['memory_mb']*1024*1024)

    if not capacity_avaliable:
      errors.append("""
        host(s) at capacity. no capsuls can be created at this time. sorry. 
      """)

    if len(errors) == 0:
      capsul_id = makeCapsulId()
      get_model().create_vm(
        email=session["account"],
        id=capsul_id,
        size=size,
        os=os_name,
        ssh_authorized_keys=[key["name"] for key in posted_keys]
      )
      current_app.config["HUB_MODEL"].create(
        email=session["account"],
        id=capsul_id,
        template_image_file_name=operating_systems[os_name]['template_image_file_name'],
        vcpus=vm_sizes[size]['vcpus'],
        memory_mb=vm_sizes[size]['memory_mb'],
        ssh_authorized_keys=[key["content"] for key in posted_keys]
      )

      return redirect(f"{url_for('console.index')}?created={capsul_id}")

  affordable_vm_sizes = {
    key: vm_size for key, vm_size in vm_sizes.items() if _can_afford(vm_size, account_balance)
  }

  for error in errors:
    flash(error)

  if not capacity_avaliable:
    current_app.logger.warning(f"when capsul capacity is restored, send an email to {session['account']}")

  return render_template(
    "create-capsul.html",
    csrf_token = session["csrf-token"],
    capacity_avaliable=capacity_avaliable,
    account_balance=format(account_balance, '.2f'),
    ssh_public_keys=public_keys_for_account,
    ssh_public_key_count=len(public_keys_for_account),
    no_ssh_public_keys=len(public_keys_for_account) == 0,
    operating_systems=operating_systems,
    cant_afford=len(affordable_vm_sizes) == 0,
    vm_sizes=affordable_vm_sizes
  )

@bp.route("/ssh", methods=("GET", "POST"))
@account_required
def ssh_public_keys():
  """List, add, or delete the current account's SSH public keys.

  Changes arrive as an HTTP POST whose ``method`` form field selects the action:
  ``"POST"`` adds a key (a blank name defaults to the key's third
  whitespace-separated token, i.e. its comment, or else its first token) and
  ``"DELETE"`` removes the key with the given name. Validation errors are
  flashed, and the key list is re-rendered in every case.
  """
  errors = list()

  if request.method == "POST":
    if "csrf-token" not in request.form or request.form['csrf-token'] != session['csrf-token']:
      return abort(418, f"u want tea")

    method = request.form["method"]
    content = None

    name = request.form["name"]
    if not name or len(name.strip()) < 1:
      if method == "POST":
        parts = re.split(" +", request.form["content"])
        if len(parts) > 2 and len(parts[2].strip()) > 0:
          name = parts[2].strip()
        else:
          name = parts[0].strip()
      else:
        errors.append("Name is required")
    # fullmatch: with re.match, "$" would also accept a trailing newline
    if not re.fullmatch(r"[0-9A-Za-z_@. -]+", name):
      errors.append("Name must match \"^[0-9A-Za-z_@. -]+$\"")

    if method == "POST":
      content = request.form["content"]
      if not content or len(content.strip()) < 1:
        errors.append("Content is required")
      else:
        # keys pasted from a terminal often wrap; join them back onto one line
        content = content.replace("\r", "").replace("\n", "")
        if not re.fullmatch(r"(ssh|ecdsa)-[0-9A-Za-z+/_=@. -]+", content):
          errors.append("Content must match \"^(ssh|ecdsa)-[0-9A-Za-z+/_=@. -]+$\"")

      if get_model().ssh_public_key_name_exists(session["account"], name):
        errors.append("A key with that name already exists")

      if len(errors) == 0:
        get_model().create_ssh_public_key(session["account"], name, content)

    elif method == "DELETE":

      if len(errors) == 0:
        get_model().delete_ssh_public_key(session["account"], name)

  for error in errors:
    flash(error)

  # show only the first and last 20 characters of each key
  keys_list = [
    dict(name=key['name'], content=f"{key['content'][:20]}...{key['content'][-20:]}")
    for key in get_model().list_ssh_public_keys_for_account(session["account"])
  ]

  return render_template(
    "ssh-public-keys.html",
    csrf_token = session["csrf-token"],
    ssh_public_keys=keys_list,
    has_ssh_public_keys=len(keys_list) > 0
  )

def get_vms():
  """Return the current account's capsuls, memoized on ``flask.g``.

  The list may include deleted capsuls; callers check each entry's ``deleted``.
  """
  if "user_vms" in g:
    return g.user_vms
  g.user_vms = get_model().list_vms_for_account(session["account"])
  return g.user_vms

def get_payments():
  """Return the current account's payments, memoized on ``flask.g``."""
  if "user_payments" in g:
    return g.user_payments
  g.user_payments = get_model().list_payments_for_account(session["account"])
  return g.user_payments


# mean calendar-month length used to turn billable days into billable months
# (roughly 365.25 / 12, rounded to two decimals)
average_number_of_days_in_a_month = 30.44

def get_vm_months_float(vm, as_of):
  """Return how many (fractional) months ``vm`` is billed for as of ``as_of``.

  Billing runs from ``vm["created"]`` until ``vm["deleted"]`` if set, otherwise
  until ``as_of``. Anything shorter than one day is billed as one full day.
  """
  end = vm["deleted"] or as_of
  elapsed_days = (end - vm["created"]).total_seconds() / (60 * 60 * 24)
  return max(elapsed_days, 1.0) / average_number_of_days_in_a_month

def get_account_balance(vms, payments, as_of):
  """Return total valid payments minus the accumulated cost of ``vms`` at ``as_of``.

  Each capsul costs ``dollars_per_month`` times its billed months (see
  get_vm_months_float). Payments flagged ``invalidated`` count as zero.
  """
  spent = 0.0
  for vm in vms:
    spent += get_vm_months_float(vm, as_of) * float(vm["dollars_per_month"])

  paid = float(sum(0 if payment["invalidated"] else payment["dollars"] for payment in payments))

  return paid - spent

@bp.route("/account-balance")
@account_required
def account_balance():
  """Render the account balance page: payments, per-capsul charges and warnings.

  Every btcpay payment session of the account is polled (poll_btcpay_session)
  before payments are read. The balance now, in one day and in one week is
  computed; the last warning from ``cli.get_warnings_list()`` whose
  ``get_active`` accepts those balances supplies the headline text.
  """
  payment_sessions = get_model().list_payment_sessions_for_account(session['account'])
  for payment_session in payment_sessions:
    if payment_session['type'] == 'btcpay':
      poll_btcpay_session(payment_session['id'])

  payments = get_payments()
  vms = get_vms()
  # one timestamp for everything below, so the projected balances, the current
  # balance and the per-capsul charges all refer to the same instant
  now = datetime.utcnow()
  balance_1w = get_account_balance(vms, payments, now + timedelta(days=7))
  balance_1d = get_account_balance(vms, payments, now + timedelta(days=1))
  balance_now = get_account_balance(vms, payments, now)

  warning_text = ""
  warnings = cli.get_warnings_list()
  # every warning is evaluated; the last active one (in list order) wins
  active_warnings = [
    warning for warning in warnings
    if warning['get_active'](balance_1w, balance_1d, balance_now)
  ]
  if active_warnings:
    # NOTE(review): len(vms) also counts deleted capsuls — confirm whether the
    # pluralization should only consider capsuls that are still running
    pluralize_capsul = "s" if len(vms) > 1 else ""
    warning_text = cli.get_warning_headline(active_warnings[-1]['id'], pluralize_capsul)

  vms_billed = list()
  for vm in vms:
    vm_months = get_vm_months_float(vm, now)
    vms_billed.append(dict(
      id=vm["id"],
      dollars_per_month=vm["dollars_per_month"],
      created=vm["created"].strftime("%b %d %Y"),
      deleted=vm["deleted"].strftime("%b %d %Y") if vm["deleted"] else "N/A",
      months=format(vm_months, '.3f'),
      dollars=format(vm_months * float(vm["dollars_per_month"]), '.2f')
    ))

  return render_template(
    "account-balance.html",
    has_vms=len(vms_billed)>0,
    vms_billed=vms_billed,
    warning_text=warning_text,
    payments=[
      dict(
        dollars=payment["dollars"],
        class_name="invalidated" if payment["invalidated"] else "",
        created=payment["created"].strftime("%b %d %Y")
      )
      for payment in payments
    ],
    has_payments=len(payments)>0,
    account_balance=format(balance_now, '.2f')
  )