# I was never able to get this type hinting to work correctly
# from psycopg2.extensions import connection as Psycopg2Connection, cursor as Psycopg2Cursor
from typing import List, Optional

from flask import current_app
from nanoid import generate

from capsulflask.shared import OnlineHost


class DBModel:
  """Data-access layer over a single psycopg2 connection/cursor pair.

  Every method issues parameterized SQL through ``self.cursor``. Methods that
  write to the database generally call ``self.connection.commit()`` themselves
  before returning.
  """

  #def __init__(self, connection: Psycopg2Connection, cursor: Psycopg2Cursor):
  def __init__(self, connection, cursor):
    """Keep the connection and cursor and request SERIALIZABLE isolation.

    NOTE(review): in PostgreSQL, SET TRANSACTION only affects the *current*
    transaction. If the connection is not in autocommit mode, this lasts only
    until the first commit() issued by a method below. Later transactions on
    the same connection run at the connection's default level. Confirm this
    is intended.
    """
    self.connection = connection
    self.cursor = cursor
    self.cursor.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")


  #     ------    LOGIN     ---------


  def login(self, email):
    """Ensure an account row exists for ``email`` and issue a login token.

    Returns a ``(token, ignore_case_matches)`` tuple:

    * ``token`` is a freshly generated token. It is None when more than two
      tokens were already issued for this email in the last 20 minutes; in
      that case nothing is committed.
    * ``ignore_case_matches`` lists other account emails that are equal to
      ``email`` ignoring case. It is only looked up when no exact-match
      account has ever logged in; otherwise it is empty.
    """
    # One query answers both "does the account exist?" and "has it ever
    # logged in?".
    self.cursor.execute("SELECT ever_logged_in FROM accounts WHERE email = %s", (email, ))
    exact_matches = self.cursor.fetchall()
    has_exact_match = len(exact_matches) > 0
    # Mirrors the SQL predicate `ever_logged_in = TRUE` (NULL/FALSE do not count).
    ever_logged_in = any(row[0] is True for row in exact_matches)

    ignore_case_matches = []
    if not ever_logged_in:
      self.cursor.execute("SELECT email FROM accounts WHERE lower_case_email = %s AND email != %s", (email.lower(), email))
      ignore_case_matches = [row[0] for row in self.cursor.fetchall()]

    if not has_exact_match:
      self.cursor.execute("INSERT INTO accounts (email, lower_case_email) VALUES (%s, %s)", (email, email.lower()))

    # Rate limit: at most three tokens per email per 20 minutes.
    self.cursor.execute("SELECT token FROM login_tokens WHERE email = %s and created > (NOW() - INTERVAL '20 min')", (email, ))
    if len(self.cursor.fetchall()) > 2:
      return (None, ignore_case_matches)

    token = generate()
    self.cursor.execute("INSERT INTO login_tokens (email, token) VALUES (%s, %s)", (email, token))
    self.connection.commit()

    return (token, ignore_case_matches)

  def consume_token(self, token):
    """Redeem a login token created within the last 20 minutes.

    On success, deletes every outstanding token for that email, marks the
    account ``ever_logged_in``, commits, and returns the email. Returns None
    when the token is unknown or expired.
    """
    self.cursor.execute("SELECT email FROM login_tokens WHERE token = %s and created > (NOW() - INTERVAL '20 min')", (token, ))
    row = self.cursor.fetchone()
    if row:
      email = row[0]
      self.cursor.execute("DELETE FROM login_tokens WHERE email = %s", (email, ))
      self.cursor.execute("UPDATE accounts SET ever_logged_in = TRUE WHERE email = %s", (email, ))
      self.connection.commit()
      return email
    return None


  #     ------    VM & ACCOUNT MANAGEMENT     ---------


  def all_non_deleted_vm_ids(self):
    """Return the ids of all VMs whose ``deleted`` column is NULL."""
    self.cursor.execute("SELECT id FROM vms WHERE deleted IS NULL")
    return [row[0] for row in self.cursor.fetchall()]

  def operating_systems_dict(self):
    """Return non-deprecated OS images keyed by image id."""
    self.cursor.execute("SELECT id, template_image_file_name, description FROM os_images WHERE deprecated = FALSE")
    return {
      row[0]: dict(template_image_file_name=row[1], description=row[2])
      for row in self.cursor.fetchall()
    }

  def vm_sizes_dict(self):
    """Return every VM size keyed by size id."""
    self.cursor.execute("SELECT id, dollars_per_month, vcpus, memory_mb, bandwidth_gb_per_month FROM vm_sizes")
    return {
      row[0]: dict(dollars_per_month=row[1], vcpus=row[2], memory_mb=row[3], bandwidth_gb_per_month=row[4])
      for row in self.cursor.fetchall()
    }

  def list_ssh_public_keys_for_account(self, email):
    """Return the SSH public keys (name, content, created) stored for ``email``."""
    self.cursor.execute("SELECT name, content, created FROM ssh_public_keys WHERE email = %s", (email, ))
    return [
      dict(name=row[0], content=row[1], created=row[2])
      for row in self.cursor.fetchall()
    ]

  def ssh_public_key_name_exists(self, email, name):
    """Return True if ``email`` already has an SSH public key called ``name``."""
    self.cursor.execute( "SELECT name FROM ssh_public_keys where email = %s AND name = %s", (email, name) )
    return len(self.cursor.fetchall()) > 0

  def create_ssh_public_key(self, email, name, content):
    """Store a new SSH public key for ``email`` and commit."""
    self.cursor.execute("""
      INSERT INTO ssh_public_keys (email, name, content)
      VALUES  (%s, %s, %s)
      """,
      (email, name, content)
    )
    self.connection.commit()

  def delete_ssh_public_key(self, email, name):
    """Delete the SSH public key ``name`` belonging to ``email`` and commit."""
    self.cursor.execute( "DELETE FROM ssh_public_keys where email = %s AND name = %s", (email, name) )
    self.connection.commit()

  def list_vms_for_account(self, email):
    """Return every VM owned by ``email``, joined with its monthly price.

    Soft-deleted VMs are included; their ``deleted`` field is set.
    """
    self.cursor.execute("""
      SELECT vms.id, vms.last_seen_ipv4, vms.last_seen_ipv6, vms.size, vms.os, vms.created, vms.deleted, vm_sizes.dollars_per_month
        FROM vms JOIN vm_sizes on vms.size = vm_sizes.id
      WHERE vms.email = %s""",
      (email, )
    )
    return [
      dict(id=row[0], ipv4=row[1], ipv6=row[2], size=row[3], os=row[4], created=row[5], deleted=row[6], dollars_per_month=row[7])
      for row in self.cursor.fetchall()
    ]

  def update_vm_ip(self, email, id, ipv4):
    """Record ``ipv4`` as the last seen IPv4 address of VM ``id`` and commit."""
    self.cursor.execute("UPDATE vms SET last_seen_ipv4 = %s WHERE email = %s AND id = %s", (ipv4, email, id))
    self.connection.commit()

  def update_vm_ssh_host_keys(self, email, id, ssh_host_keys):
    """Insert one vm_ssh_host_key row per entry of ``ssh_host_keys`` and commit.

    Each entry is a mapping with ``key_type``, ``content`` and ``sha256``
    keys. Existing host-key rows for the VM are not removed first.
    """
    for key in ssh_host_keys:
      self.cursor.execute("""
        INSERT INTO vm_ssh_host_key (email, vm_id, key_type, content, sha256)
        VALUES  (%s, %s, %s, %s, %s)
        """,
        (email, id, key['key_type'], key['content'], key['sha256'])
      )
    self.connection.commit()

  def create_vm(self, email, id, size, os, ssh_authorized_keys):
    """Insert a VM row plus its authorized SSH key names, in one commit.

    ``ssh_authorized_keys`` is an iterable of ssh_public_keys names.
    """
    self.cursor.execute("""
      INSERT INTO vms (email, id, size, os)
      VALUES  (%s, %s, %s, %s)
      """,
      (email, id, size, os)
    )

    for ssh_authorized_key in ssh_authorized_keys:
      self.cursor.execute("""
        INSERT INTO vm_ssh_authorized_key (email, vm_id, ssh_public_key_name)
        VALUES  (%s, %s, %s)
        """,
        (email, id, ssh_authorized_key)
      )
    self.connection.commit()

  def delete_vm(self, email, id):
    """Soft-delete VM ``id``: set ``deleted`` to now() and commit; the row stays."""
    self.cursor.execute("UPDATE vms SET deleted = now() WHERE email = %s AND id = %s", ( email, id))
    self.connection.commit()

  def get_vm_detail(self, email, id):
    """Return a detail dict for VM ``id`` owned by ``email``, or None if not found.

    The dict combines the vms row with its OS description and size details,
    plus ``ssh_authorized_keys`` (key names) and ``ssh_host_keys`` (dicts).
    """
    self.cursor.execute("""
      SELECT vms.id, vms.last_seen_ipv4, vms.last_seen_ipv6, os_images.description, vms.created, vms.deleted,
            vm_sizes.id, vm_sizes.dollars_per_month, vm_sizes.vcpus, vm_sizes.memory_mb, vm_sizes.bandwidth_gb_per_month
      FROM vms
        JOIN os_images on vms.os = os_images.id
        JOIN vm_sizes on vms.size = vm_sizes.id
      WHERE vms.email = %s AND vms.id = %s""",
      (email, id)
    )
    row = self.cursor.fetchone()
    if not row:
      return None

    vm = dict(
      id=row[0], ipv4=row[1], ipv6=row[2], os_description=row[3], created=row[4], deleted=row[5],
      size=row[6], dollars_per_month=row[7], vcpus=row[8], memory_mb=row[9], bandwidth_gb_per_month=row[10]
    )

    self.cursor.execute("""
      SELECT ssh_public_key_name FROM vm_ssh_authorized_key
      WHERE vm_ssh_authorized_key.email = %s AND vm_ssh_authorized_key.vm_id = %s""",
      (email, id)
    )
    vm["ssh_authorized_keys"] = [key_row[0] for key_row in self.cursor.fetchall()]

    self.cursor.execute("""
      SELECT key_type, content, sha256 FROM vm_ssh_host_key
      WHERE vm_ssh_host_key.email = %s AND vm_ssh_host_key.vm_id = %s""",
      (email, id)
    )
    vm["ssh_host_keys"] = [
      dict(key_type=key_row[0], content=key_row[1], sha256=key_row[2])
      for key_row in self.cursor.fetchall()
    ]

    return vm


  #     ------    PAYMENTS & ACCOUNT BALANCE     ---------


  def list_payments_for_account(self, email):
    """Return all payments (id, dollars, invalidated, created) recorded for ``email``."""
    self.cursor.execute("""
      SELECT id, dollars, invalidated, created
      FROM payments WHERE payments.email = %s""",
      (email, )
    )
    return [
      dict(id=row[0], dollars=row[1], invalidated=row[2], created=row[3])
      for row in self.cursor.fetchall()
    ]

  def create_payment_session(self, payment_type, id, email, dollars):
    """Record a pending payment session ``id`` of type ``payment_type`` and commit."""
    self.cursor.execute("""
      INSERT INTO payment_sessions (id, type, email, dollars)
      VALUES  (%s, %s, %s, %s)
      """,
      (id, payment_type, email, dollars)
    )
    self.connection.commit()

  def list_payment_sessions_for_account(self, email):
    """Return all payment sessions (id, type, dollars, created) for ``email``."""
    self.cursor.execute("""
      SELECT id, type, dollars, created
      FROM payment_sessions WHERE email = %s""",
      (email, )
    )
    return [
      dict(id=row[0], type=row[1], dollars=row[2], created=row[3])
      for row in self.cursor.fetchall()
    ]

  def payment_session_redirect(self, email, id):
    """Set ``redirected = TRUE`` on session ``id`` and return its previous value.

    Returns None (without writing) if ``email`` has no such session.
    """
    self.cursor.execute("SELECT redirected FROM payment_sessions WHERE email = %s AND id = %s",
      (email, id)
    )
    row = self.cursor.fetchone()
    if row:
      self.cursor.execute("UPDATE payment_sessions SET redirected = TRUE WHERE email = %s AND id = %s",
        (email, id)
      )
      self.connection.commit()
      return row[0]

    return None


  def consume_payment_session(self, payment_type, id, dollars):
    """Turn payment session ``id`` of ``payment_type`` into a payment.

    When the session exists, this deletes it and inserts a payment for the
    dollar amount *recorded on the session*. A mismatch with ``dollars`` is
    only logged. For ``"btcpay"`` sessions, the invoice is also added to
    unresolved_btcpay_invoices. It then commits and returns the account
    email. Returns None when the session is unknown.
    """
    self.cursor.execute("SELECT email, dollars FROM payment_sessions WHERE id = %s AND type = %s", (id, payment_type))
    row = self.cursor.fetchone()
    if row:
      if int(row[1]) != int(dollars):
        current_app.logger.warning(f"""
          {payment_type} gave us a completed payment session with a different dollar amount than what we had recorded!!
          id: {id}
          account: {row[0]}
          Recorded amount: {int(row[1])}
          {payment_type} sent: {int(dollars)}
        """)
        # not sure what to do here. For now just log and do nothing.
      self.cursor.execute( "DELETE FROM payment_sessions WHERE id = %s AND type = %s", (id, payment_type) )
      self.cursor.execute( "INSERT INTO payments (email, dollars) VALUES (%s, %s) RETURNING id", (row[0], row[1]) )

      if payment_type == "btcpay":
        # The RETURNING id of the payment just inserted links it to the invoice.
        payment_id = self.cursor.fetchone()[0]
        self.cursor.execute(
          "INSERT INTO unresolved_btcpay_invoices (id, email, payment_id) VALUES (%s, %s, %s)",
           (id, row[0], payment_id)
        )

      self.connection.commit()
      return row[0]
    else:
      return None

  def delete_payment_session(self, payment_type, id):
    """Delete payment session ``id`` of ``payment_type`` and commit."""
    self.cursor.execute( "DELETE FROM payment_sessions WHERE id = %s AND type = %s", (id, payment_type) )
    self.connection.commit()

  def btcpay_invoice_resolved(self, id, completed):
    """Remove BTCPay invoice ``id`` from the unresolved list.

    If ``completed`` is falsy, the linked payment is also marked
    ``invalidated``. Nothing is written (and nothing is committed) when the
    invoice is not in unresolved_btcpay_invoices.
    """
    self.cursor.execute("SELECT email, payment_id FROM unresolved_btcpay_invoices WHERE id = %s ", (id,))
    row = self.cursor.fetchone()
    if row:
      self.cursor.execute( "DELETE FROM unresolved_btcpay_invoices WHERE id = %s", (id,) )
      if not completed:
        # The original statement was missing the AND between the two
        # predicates, which is a SQL syntax error.
        self.cursor.execute("UPDATE payments SET invalidated = TRUE WHERE email = %s AND id = %s", (row[0], row[1]))

      self.connection.commit()

  def get_unresolved_btcpay_invoices(self):
    """Return every unresolved BTCPay invoice with its payment's created time and dollars."""
    self.cursor.execute("""
      SELECT unresolved_btcpay_invoices.id, payments.created, payments.dollars, unresolved_btcpay_invoices.email
      FROM unresolved_btcpay_invoices JOIN payments on payment_id = payments.id
    """)
    return [
      dict(id=row[0], created=row[1], dollars=row[2], email=row[3])
      for row in self.cursor.fetchall()
    ]

  def get_account_balance_warning(self, email):
    """Return the account_balance_warning of ``email``, or None if there is no such account."""
    self.cursor.execute("SELECT account_balance_warning FROM accounts WHERE email = %s", (email,))
    row = self.cursor.fetchone()
    # Previously an unknown email raised TypeError on `None[0]`.
    return row[0] if row else None

  def set_account_balance_warning(self, email, account_balance_warning):
    """Store ``account_balance_warning`` on the account ``email`` and commit."""
    self.cursor.execute("UPDATE accounts SET account_balance_warning = %s WHERE email = %s", (account_balance_warning, email))
    self.connection.commit()

  def all_accounts(self):
    """Return email and account_balance_warning for every account that has ever logged in."""
    self.cursor.execute("SELECT email, account_balance_warning FROM accounts WHERE ever_logged_in = TRUE ")
    return [
      dict(email=row[0], account_balance_warning=row[1])
      for row in self.cursor.fetchall()
    ]


  #     ------    HOSTS     ---------

  def authorized_for_host(self, id, token) -> bool:
    """Return True if a host with this ``id`` and ``token`` exists."""
    self.cursor.execute("SELECT id FROM hosts WHERE id = %s AND token = %s", (id, token))
    return self.cursor.fetchone() is not None

  def host_heartbeat(self, id) -> None:
    """Set the host's last_health_check to NOW() and commit."""
    self.cursor.execute("UPDATE hosts SET last_health_check = NOW() WHERE id = %s", (id,))
    self.connection.commit()

  def get_all_hosts(self) -> List[OnlineHost]:
    """Return every registered host, regardless of health."""
    self.cursor.execute("SELECT id, https_url FROM hosts")
    return [OnlineHost(id=row[0], url=row[1]) for row in self.cursor.fetchall()]

  def get_online_hosts(self) -> List[OnlineHost]:
    """Return hosts whose last health check is less than 20 seconds old."""
    self.cursor.execute("SELECT id, https_url FROM hosts WHERE last_health_check > NOW() - INTERVAL '20 seconds'")
    return [OnlineHost(id=row[0], url=row[1]) for row in self.cursor.fetchall()]

  def create_operation(self, online_hosts: List[OnlineHost], email: str, payload: str) -> int:
    """Insert an operation and one host_operation row per host, commit, and return its id."""
    self.cursor.execute( "INSERT INTO operations (email, payload) VALUES (%s, %s) RETURNING id", (email, payload) )
    operation_id = self.cursor.fetchone()[0]

    for host in online_hosts:
      self.cursor.execute( "INSERT INTO host_operation (host, operation) VALUES (%s, %s)", (host.id, operation_id) )

    self.connection.commit()
    return operation_id

  def update_host_operation(self, host_id: str, operation_id: int, assignment_status: Optional[str], result: Optional[str]):
    """Update the host_operation row for (``host_id``, ``operation_id``).

    If ``assignment_status`` is truthy, sets it and stamps ``assigned``. If
    ``result`` is truthy, stores it in ``results`` and stamps ``completed``.
    When neither is truthy, nothing is updated, though commit() is still
    called.
    """
    if assignment_status and not result:
      self.cursor.execute(
        "UPDATE host_operation SET assignment_status = %s, assigned = NOW() WHERE host = %s AND operation = %s",
        (assignment_status, host_id, operation_id)
      )
    elif not assignment_status and result:
      self.cursor.execute(
        "UPDATE host_operation SET results = %s, completed = NOW() WHERE host = %s AND operation = %s",
        (result, host_id, operation_id)
      )
    elif assignment_status and result:
      self.cursor.execute(
        "UPDATE host_operation SET assignment_status = %s, assigned = NOW(), results = %s, completed = NOW() WHERE host = %s AND operation = %s",
        (assignment_status, result, host_id, operation_id)
      )
    self.connection.commit()

  def host_of_capsul(self, capsul_id: str) -> Optional[OnlineHost]:
    """Return the host that VM ``capsul_id`` lives on, or None if not found."""
    self.cursor.execute("SELECT hosts.id, hosts.https_url from vms JOIN hosts on hosts.id = vms.host where vms.id = %s",  (capsul_id,))
    row = self.cursor.fetchone()
    if row:
      return OnlineHost(id=row[0], url=row[1])
    else:
      return None

  def host_operation_exists(self, operation_id: int, host_id: str) -> bool:
    """Return True if a host_operation row exists for this host and operation."""
    self.cursor.execute("SELECT operation FROM host_operation WHERE host = %s AND operation = %s",(host_id, operation_id))
    return len(self.cursor.fetchall()) != 0

  def claim_operation(self, operation_id: int, host_id: str) -> bool:
    """Try to mark ``operation_id`` as 'assigned' to ``host_id``; commit and report success.

    The UPDATE matches only when this host has a host_operation row for the
    operation and no host_operation row for that operation is already
    'assigned'. It returns True when the UPDATE changed a row, i.e. this
    host won the claim.

    NOTE(review): if this runs at READ COMMITTED, two hosts claiming
    concurrently may both see no 'assigned' row and both succeed, because
    they update different rows. Confirm whether stronger isolation or
    locking is required.
    """
    # A single UPDATE statement. The previous version wrapped it in
    # BEGIN/COMMIT inside the same execute(). That left cursor.rowcount
    # describing the COMMIT (-1) rather than the UPDATE, so the method always
    # returned True. The BEGIN also ran inside psycopg2's already-open
    # transaction.
    self.cursor.execute("""
      UPDATE host_operation SET assignment_status = 'assigned'
        WHERE host = %s AND operation = %s AND NOT EXISTS (
          SELECT 1 FROM host_operation WHERE operation = %s AND assignment_status = 'assigned'
        );
    """,  (host_id, operation_id, operation_id))

    # Read rowcount before anything else runs on this cursor. Use > 0 because
    # psycopg2 reports -1 when the count is unknown.
    claimed = self.cursor.rowcount > 0

    self.connection.commit()

    return claimed