import re
from typing import List, Optional

from flask import current_app
from nanoid import generate

from capsulflask.shared import OnlineHost

# I was never able to get this type hinting to work correctly
# from psycopg2.extensions import connection as Psycopg2Connection, cursor as Psycopg2Cursor


class DBModel:
    """Data-access layer: every method runs SQL through the given DB-API cursor.

    Methods that modify data call ``self.connection.commit()`` themselves.
    """

    #def __init__(self, connection: Psycopg2Connection, cursor: Psycopg2Cursor):
    def __init__(self, connection, cursor):
        """Store the connection (used for commits) and cursor (used for queries)."""
        self.connection = connection
        self.cursor = cursor

    @staticmethod
    def _validate_host_id(host_id):
        """Raise ValueError unless host_id consists only of ``[a-zA-Z0-9_-]``."""
        if not re.match(r"^[a-zA-Z0-9_-]+$", host_id):
            raise ValueError(f"host_id \"{host_id}\" must match \"^[a-zA-Z0-9_-]+\"")

    # ------ LOGIN ---------

    def login(self, email):
        """Create a login token for ``email``, creating the account row if needed.

        Returns:
            A tuple ``(token, ignoreCaseMatches)``. ``token`` is None when more
            than two tokens were already created for this email within the
            last 20 minutes. ``ignoreCaseMatches`` lists other account emails
            that differ from ``email`` only by case; it is only looked up when
            no account with this exact email has ever logged in.
        """
        self.cursor.execute("SELECT * FROM accounts WHERE email = %s", (email, ))
        has_exact_match = len(self.cursor.fetchall()) > 0

        self.cursor.execute(
            "SELECT * FROM accounts WHERE email = %s AND ever_logged_in = TRUE",
            (email, )
        )
        ever_logged_in = len(self.cursor.fetchall()) > 0

        ignore_case_matches = []
        if not ever_logged_in:
            self.cursor.execute(
                "SELECT email FROM accounts WHERE lower_case_email = %s AND email != %s",
                (email.lower(), email)
            )
            ignore_case_matches = [row[0] for row in self.cursor.fetchall()]

        if not has_exact_match:
            self.cursor.execute(
                "INSERT INTO accounts (email, lower_case_email) VALUES (%s, %s)",
                (email, email.lower())
            )

        # Rate limit: refuse to issue another token if more than two are still fresh.
        self.cursor.execute(
            "SELECT token FROM login_tokens WHERE email = %s and created > (NOW() - INTERVAL '20 min')",
            (email, )
        )
        if len(self.cursor.fetchall()) > 2:
            return (None, ignore_case_matches)

        token = generate()
        self.cursor.execute(
            "INSERT INTO login_tokens (email, token) VALUES (%s, %s)",
            (email, token)
        )
        self.connection.commit()

        return (token, ignore_case_matches)

    def consume_token(self, token):
        """Exchange a login token created within the last 20 minutes for its email.

        On success all login tokens for that email are deleted and the account
        is marked ``ever_logged_in``. Returns the email, or None if no
        matching, unexpired token exists.
        """
        self.cursor.execute(
            "SELECT email FROM login_tokens WHERE token = %s and created > (NOW() - INTERVAL '20 min')",
            (token, )
        )
        row = self.cursor.fetchone()
        if not row:
            return None

        email = row[0]
        self.cursor.execute("DELETE FROM login_tokens WHERE email = %s", (email, ))
        self.cursor.execute("UPDATE accounts SET ever_logged_in = TRUE WHERE email = %s", (email, ))
        self.connection.commit()
        return email

    # ------ VM & ACCOUNT MANAGEMENT ---------

    def non_deleted_vms_by_host_and_network(self, host_id):
        """Return non-deleted VMs grouped as ``{host: {network_name: [vm_dict, ...]}}``.

        Args:
            host_id: restrict the result to this host, or None for all hosts.

        Raises:
            ValueError: if ``host_id`` is given but contains characters
                outside ``[a-zA-Z0-9_-]``.
        """
        query = "SELECT id, email, desired_state, host, network_name, public_ipv4, public_ipv6 FROM vms WHERE deleted IS NULL"
        if host_id is None:
            self.cursor.execute(query)
        else:
            self._validate_host_id(host_id)
            # host_id is passed as a bound parameter. The parameters argument
            # must be a sequence (hence the one-element tuple); passing a bare
            # string is a documented cause of "TypeError: not all arguments
            # converted during string formatting".
            self.cursor.execute(query + " AND host = %s", (host_id,))

        hosts = dict()
        for row in self.cursor.fetchall():
            row_host = row[3]
            network_name = row[4]
            networks = hosts.setdefault(row_host, dict())
            networks.setdefault(network_name, []).append(
                dict(
                    id=row[0],
                    email=row[1],
                    host=row_host,
                    network_name=network_name,
                    desired_state=row[2],
                    public_ipv4=row[5],
                    public_ipv6=row[6],
                )
            )

        return hosts

    def set_desired_state(self, email, vm_id, desired_state):
        """Set ``desired_state`` on the VM identified by ``email`` and ``vm_id``."""
        self.cursor.execute(
            "UPDATE vms SET desired_state = %s WHERE email = %s AND id = %s",
            (desired_state, email, vm_id)
        )
        self.connection.commit()

    def all_accounts_with_active_vms(self):
        """Return the distinct emails that own at least one non-deleted VM."""
        self.cursor.execute("SELECT DISTINCT email FROM vms WHERE deleted IS NULL")
        return [row[0] for row in self.cursor.fetchall()]

    def operating_systems_dict(self):
        """Return non-deprecated OS images as ``{id: {template_image_file_name, description}}``."""
        self.cursor.execute("SELECT id, template_image_file_name, description FROM os_images WHERE deprecated = FALSE")
        return {
            row[0]: dict(template_image_file_name=row[1], description=row[2])
            for row in self.cursor.fetchall()
        }

    def vm_sizes_dict(self):
        """Return all VM sizes as ``{id: {dollars_per_month, vcpus, memory_mb, bandwidth_gb_per_month}}``."""
        self.cursor.execute("SELECT id, dollars_per_month, vcpus, memory_mb, bandwidth_gb_per_month FROM vm_sizes")
        return {
            row[0]: dict(
                dollars_per_month=row[1],
                vcpus=row[2],
                memory_mb=row[3],
                bandwidth_gb_per_month=row[4],
            )
            for row in self.cursor.fetchall()
        }

    def list_ssh_public_keys_for_account(self, email):
        """Return the account's SSH public keys as dicts with name, content and created."""
        self.cursor.execute("SELECT name, content, created FROM ssh_public_keys WHERE email = %s", (email, ))
        return [
            dict(name=row[0], content=row[1], created=row[2])
            for row in self.cursor.fetchall()
        ]

    def ssh_public_key_name_exists(self, email, name):
        """Return True if the account already has an SSH public key called ``name``."""
        self.cursor.execute(
            "SELECT name FROM ssh_public_keys where email = %s AND name = %s",
            (email, name)
        )
        return len(self.cursor.fetchall()) > 0

    def create_ssh_public_key(self, email, name, content):
        """Insert a new SSH public key for the account."""
        self.cursor.execute("""
            INSERT INTO ssh_public_keys (email, name, content)
            VALUES (%s, %s, %s)
            """,
            (email, name, content)
        )
        self.connection.commit()

    def delete_ssh_public_key(self, email, name):
        """Delete the account's SSH public key called ``name``."""
        self.cursor.execute(
            "DELETE FROM ssh_public_keys where email = %s AND name = %s",
            (email, name)
        )
        self.connection.commit()

    def list_vms_for_account(self, email):
        """Return every VM row for the account (deleted ones included) joined with its monthly price."""
        self.cursor.execute("""
            SELECT vms.id, vms.public_ipv4, vms.public_ipv6, vms.size, vms.shortterm, vms.os, vms.created, vms.deleted, vm_sizes.dollars_per_month
            FROM vms JOIN vm_sizes on vms.size = vm_sizes.id
            WHERE vms.email = %s""",
            (email, )
        )
        return [
            dict(
                id=row[0], ipv4=row[1], ipv6=row[2], size=row[3], shortterm=row[4],
                os=row[5], created=row[6], deleted=row[7], dollars_per_month=row[8],
            )
            for row in self.cursor.fetchall()
        ]

    def update_vm_ip(self, email, id, ipv4):
        """Set ``public_ipv4`` on the VM identified by ``email`` and ``id``."""
        self.cursor.execute(
            "UPDATE vms SET public_ipv4 = %s WHERE email = %s AND id = %s",
            (ipv4, email, id)
        )
        self.connection.commit()

    def update_vm_ssh_host_keys(self, email, id, ssh_host_keys):
        """Insert one ``vm_ssh_host_key`` row per entry of ``ssh_host_keys``.

        Each entry must be a mapping with the keys ``key_type``, ``content``
        and ``sha256``.
        """
        for key in ssh_host_keys:
            self.cursor.execute("""
                INSERT INTO vm_ssh_host_key (email, vm_id, key_type, content, sha256)
                VALUES (%s, %s, %s, %s, %s)
                """,
                (email, id, key['key_type'], key['content'], key['sha256'])
            )
        self.connection.commit()

    def create_vm(self, email, id, size, shortterm, os, host, network_name, public_ipv4, ssh_authorized_keys):
        """Insert a VM row plus one ``vm_ssh_authorized_key`` row per authorized key name."""
        self.cursor.execute("""
            INSERT INTO vms (email, id, size, shortterm, os, host, network_name, public_ipv4)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            """,
            (email, id, size, shortterm, os, host, network_name, public_ipv4)
        )

        for ssh_authorized_key in ssh_authorized_keys:
            self.cursor.execute("""
                INSERT INTO vm_ssh_authorized_key (email, vm_id, ssh_public_key_name)
                VALUES (%s, %s, %s)
                """,
                (email, id, ssh_authorized_key)
            )
        self.connection.commit()

    def delete_vm(self, email, id):
        """Soft-delete the VM by setting its ``deleted`` timestamp to now()."""
        self.cursor.execute(
            "UPDATE vms SET deleted = now() WHERE email = %s AND id = %s",
            (email, id)
        )
        self.connection.commit()

    def get_vm_detail(self, email, id):
        """Return a detail dict for one VM, or None if the account has no such VM.

        The dict carries the VM's own columns, its OS description, its size
        attributes, plus ``ssh_authorized_keys`` (list of key names) and
        ``ssh_host_keys`` (list of dicts with key_type, content, sha256).
        """
        self.cursor.execute("""
            SELECT vms.id, vms.public_ipv4, vms.public_ipv6, os_images.description, vms.created, vms.deleted,
                   vm_sizes.id, vms.shortterm, vm_sizes.dollars_per_month, vm_sizes.vcpus, vm_sizes.memory_mb, vm_sizes.bandwidth_gb_per_month
            FROM vms
            JOIN os_images on vms.os = os_images.id
            JOIN vm_sizes on vms.size = vm_sizes.id
            WHERE vms.email = %s AND vms.id = %s""",
            (email, id)
        )
        row = self.cursor.fetchone()
        if not row:
            return None

        vm = dict(
            id=row[0], ipv4=row[1], ipv6=row[2], os_description=row[3], created=row[4], deleted=row[5],
            size=row[6], shortterm=row[7], dollars_per_month=row[8], vcpus=row[9], memory_mb=row[10],
            bandwidth_gb_per_month=row[11],
        )

        self.cursor.execute("""
            SELECT ssh_public_key_name FROM vm_ssh_authorized_key
            WHERE vm_ssh_authorized_key.email = %s AND vm_ssh_authorized_key.vm_id = %s""",
            (email, id)
        )
        vm["ssh_authorized_keys"] = [key_row[0] for key_row in self.cursor.fetchall()]

        self.cursor.execute("""
            SELECT key_type, content, sha256 FROM vm_ssh_host_key
            WHERE vm_ssh_host_key.email = %s AND vm_ssh_host_key.vm_id = %s""",
            (email, id)
        )
        vm["ssh_host_keys"] = [
            dict(key_type=key_row[0], content=key_row[1], sha256=key_row[2])
            for key_row in self.cursor.fetchall()
        ]

        return vm

    def clear_shortterm_flag(self, email):
        """Set ``shortterm = FALSE`` on every VM belonging to the account."""
        self.cursor.execute("UPDATE vms SET shortterm = FALSE WHERE email = %s", (email,))
        self.connection.commit()

    # ------ PAYMENTS & ACCOUNT BALANCE ---------

    def list_payments_for_account(self, email):
        """Return the account's payments as dicts with id, dollars, invalidated and created."""
        self.cursor.execute("""
            SELECT id, dollars, invalidated, created
            FROM payments WHERE payments.email = %s""",
            (email, )
        )
        return [
            dict(id=row[0], dollars=row[1], invalidated=row[2], created=row[3])
            for row in self.cursor.fetchall()
        ]

    def create_payment_session(self, payment_type, id, email, dollars):
        """Record a new payment session of the given type for the account."""
        self.cursor.execute("""
            INSERT INTO payment_sessions (id, type, email, dollars)
            VALUES (%s, %s, %s, %s)
            """,
            (id, payment_type, email, dollars)
        )
        self.connection.commit()

    def list_payment_sessions_for_account(self, email):
        """Return the account's payment sessions as dicts with id, type, dollars and created."""
        self.cursor.execute("""
            SELECT id, type, dollars, created
            FROM payment_sessions WHERE email = %s""",
            (email, )
        )
        return [
            dict(id=row[0], type=row[1], dollars=row[2], created=row[3])
            for row in self.cursor.fetchall()
        ]

    def payment_session_redirect(self, email, id):
        """Mark a payment session as redirected and report its previous state.

        Returns the value ``redirected`` had before this call, or None if the
        account has no payment session with this id.
        """
        self.cursor.execute(
            "SELECT redirected FROM payment_sessions WHERE email = %s AND id = %s",
            (email, id)
        )
        row = self.cursor.fetchone()
        if not row:
            return None

        self.cursor.execute(
            "UPDATE payment_sessions SET redirected = TRUE WHERE email = %s AND id = %s",
            (email, id)
        )
        self.connection.commit()
        return row[0]

    def consume_payment_session(self, payment_type, id, dollars):
        """Convert a completed payment session into a payment row.

        The session is deleted and a payment is inserted for the *recorded*
        dollar amount (a mismatch with ``dollars`` is only logged). For
        ``payment_type == "btcpay"`` an ``unresolved_btcpay_invoices`` row is
        also inserted, pointing at the new payment.

        Returns the account email, or None if no such session exists.
        """
        self.cursor.execute(
            "SELECT email, dollars FROM payment_sessions WHERE id = %s AND type = %s",
            (id, payment_type)
        )
        row = self.cursor.fetchone()
        if not row:
            return None

        if int(row[1]) != int(dollars):
            current_app.logger.warning(f"""
              {payment_type} gave us a completed payment session with a different dollar amount than what we had recorded!!
              id: {id}
              account: {row[0]}
              Recorded amount: {int(row[1])}
              {payment_type} sent: {int(dollars)}
            """)
            # not sure what to do here. For now just log and do nothing.

        self.cursor.execute(
            "DELETE FROM payment_sessions WHERE id = %s AND type = %s",
            (id, payment_type)
        )
        self.cursor.execute(
            "INSERT INTO payments (email, dollars) VALUES (%s, %s) RETURNING id",
            (row[0], row[1])
        )

        if payment_type == "btcpay":
            payment_id = self.cursor.fetchone()[0]
            self.cursor.execute(
                "INSERT INTO unresolved_btcpay_invoices (id, email, payment_id) VALUES (%s, %s, %s)",
                (id, row[0], payment_id)
            )

        self.connection.commit()
        return row[0]

    def delete_payment_session(self, payment_type, id):
        """Delete the payment session with this id and type."""
        self.cursor.execute(
            "DELETE FROM payment_sessions WHERE id = %s AND type = %s",
            (id, payment_type)
        )
        self.connection.commit()

    def btcpay_invoice_resolved(self, id, completed) -> Optional[str]:
        """Resolve a pending btcpay invoice.

        The ``unresolved_btcpay_invoices`` row is deleted. If the invoice did
        not complete, the associated payment is marked ``invalidated``.

        Returns the account email when ``completed`` is truthy and the invoice
        was found; otherwise None.
        """
        self.cursor.execute(
            "SELECT email, payment_id FROM unresolved_btcpay_invoices WHERE id = %s ",
            (id,)
        )
        row = self.cursor.fetchone()
        if row:
            self.cursor.execute(
                "DELETE FROM unresolved_btcpay_invoices WHERE id = %s",
                (id,)
            )
            if not completed:
                # The two conditions must be joined with AND; without it the
                # statement is not valid SQL.
                self.cursor.execute(
                    "UPDATE payments SET invalidated = TRUE WHERE email = %s AND id = %s",
                    (row[0], row[1])
                )

            self.connection.commit()

            if completed:
                return row[0]

        return None

    def get_unresolved_btcpay_invoices(self):
        """Return unresolved btcpay invoices as dicts with id, created, dollars and email."""
        self.cursor.execute("""
            SELECT unresolved_btcpay_invoices.id, payments.created, payments.dollars, unresolved_btcpay_invoices.email
            FROM unresolved_btcpay_invoices JOIN payments on payment_id = payments.id
            """)
        return [
            dict(id=row[0], created=row[1], dollars=row[2], email=row[3])
            for row in self.cursor.fetchall()
        ]

    def get_account_balance_warning(self, email):
        """Return the ``account_balance_warning`` column for the account.

        Raises:
            TypeError: if no account row exists for ``email`` (fetchone()
                returns None, which cannot be indexed).
        """
        self.cursor.execute("SELECT account_balance_warning FROM accounts WHERE email = %s", (email,))
        return self.cursor.fetchone()[0]

    def set_account_balance_warning(self, email, account_balance_warning):
        """Overwrite the ``account_balance_warning`` column for the account."""
        self.cursor.execute(
            "UPDATE accounts SET account_balance_warning = %s WHERE email = %s",
            (account_balance_warning, email)
        )
        self.connection.commit()

    def all_accounts(self):
        """Return every account that has ever logged in as dicts with email and account_balance_warning."""
        self.cursor.execute("SELECT email, account_balance_warning FROM accounts WHERE ever_logged_in = TRUE ")
        return [
            dict(email=row[0], account_balance_warning=row[1])
            for row in self.cursor.fetchall()
        ]

    # ------ HOSTS ---------

    def list_hosts_with_networks(self, host_id: str):
        """Return hosts as ``{host_id: {last_health_check, networks: [network_dict, ...]}}``.

        Args:
            host_id: restrict the result to this host, or None for all hosts.

        Raises:
            ValueError: if ``host_id`` is given but contains characters
                outside ``[a-zA-Z0-9_-]``.
        """
        query = """
            SELECT hosts.id, hosts.last_health_check, host_network.network_name,
                   host_network.virtual_bridge_name, host_network.public_ipv4_cidr_block,
                   host_network.public_ipv4_first_usable_ip, host_network.public_ipv4_last_usable_ip
            FROM hosts
            JOIN host_network ON host_network.host = hosts.id
        """
        if host_id is None:
            self.cursor.execute(query)
        else:
            self._validate_host_id(host_id)
            # host_id is passed as a bound parameter. The parameters argument
            # must be a sequence (hence the one-element tuple); passing a bare
            # string is a documented cause of "TypeError: not all arguments
            # converted during string formatting".
            self.cursor.execute(query + " WHERE hosts.id = %s", (host_id,))

        hosts = dict()
        for row in self.cursor.fetchall():
            host = hosts.setdefault(row[0], dict(last_health_check=row[1], networks=[]))
            host["networks"].append(dict(
                network_name=row[2],
                virtual_bridge_name=row[3],
                public_ipv4_cidr_block=row[4],
                public_ipv4_first_usable_ip=row[5],
                public_ipv4_last_usable_ip=row[6],
            ))

        return hosts

    def authorized_for_host(self, id, token) -> bool:
        """Return True if a host row exists with this id and token."""
        self.cursor.execute("SELECT id FROM hosts WHERE id = %s AND token = %s", (id, token))
        return self.cursor.fetchone() is not None

    def host_heartbeat(self, id) -> None:
        """Set the host's ``last_health_check`` to NOW()."""
        self.cursor.execute("UPDATE hosts SET last_health_check = NOW() WHERE id = %s", (id,))
        self.connection.commit()

    def get_all_hosts(self) -> List[OnlineHost]:
        """Return every host as an OnlineHost(id, url)."""
        self.cursor.execute("SELECT id, https_url FROM hosts")
        return [OnlineHost(id=row[0], url=row[1]) for row in self.cursor.fetchall()]

    def get_online_hosts(self) -> List[OnlineHost]:
        """Return hosts whose last health check is less than 20 seconds old."""
        self.cursor.execute("SELECT id, https_url FROM hosts WHERE last_health_check > NOW() - INTERVAL '20 seconds'")
        return [OnlineHost(id=row[0], url=row[1]) for row in self.cursor.fetchall()]

    def list_all_operations(self):
        """Return operations as ``{operation_id: {email, created, payload, hosts: [...]}}``.

        Each ``hosts`` entry holds the per-host columns from ``host_operation``:
        host, assignment_status, assigned, completed and results.
        """
        # Nine columns are selected; the comma after host_operation.host is
        # required so that row[4]..row[8] line up with the dict built below.
        self.cursor.execute("""
            SELECT operations.id, operations.email, operations.created, operations.payload,
                   host_operation.host, host_operation.assignment_status, host_operation.assigned,
                   host_operation.completed, host_operation.results
            FROM operations
            JOIN host_operation ON host_operation.operation = operations.id
        """)

        operations = dict()
        for row in self.cursor.fetchall():
            operation = operations.setdefault(
                row[0],
                dict(email=row[1], created=row[2], payload=row[3], hosts=[])
            )
            operation["hosts"].append(dict(
                host=row[4],
                assignment_status=row[5],
                assigned=row[6],
                completed=row[7],
                results=row[8],
            ))

        return operations

    def create_operation(self, online_hosts: List[OnlineHost], email: str, payload: str) -> int:
        """Insert an operation plus one ``host_operation`` row per host; return the new operation id."""
        self.cursor.execute(
            "INSERT INTO operations (email, payload) VALUES (%s, %s) RETURNING id",
            (email, payload)
        )
        operation_id = self.cursor.fetchone()[0]

        for host in online_hosts:
            self.cursor.execute(
                "INSERT INTO host_operation (host, operation) VALUES (%s, %s)",
                (host.id, operation_id)
            )

        self.connection.commit()
        return operation_id

    def update_operation(self, operation_id: int, payload: str):
        """Replace the payload of an existing operation."""
        self.cursor.execute(
            "UPDATE operations SET payload = %s WHERE id = %s",
            (payload, operation_id)
        )
        self.connection.commit()

    def update_host_operation(self, host_id: str, operation_id: int, assignment_status: str, result: str):
        """Record an assignment status and/or a result for one host's share of an operation.

        Whichever of ``assignment_status`` / ``result`` is truthy gets written,
        together with its timestamp column (``assigned`` / ``completed``).
        If both are falsy nothing is updated.
        """
        if assignment_status and not result:
            self.cursor.execute(
                "UPDATE host_operation SET assignment_status = %s, assigned = NOW() WHERE host = %s AND operation = %s",
                (assignment_status, host_id, operation_id)
            )
        elif not assignment_status and result:
            self.cursor.execute(
                "UPDATE host_operation SET results = %s, completed = NOW() WHERE host = %s AND operation = %s",
                (result, host_id, operation_id)
            )
        elif assignment_status and result:
            self.cursor.execute(
                "UPDATE host_operation SET assignment_status = %s, assigned = NOW(), results = %s, completed = NOW() WHERE host = %s AND operation = %s",
                (assignment_status, result, host_id, operation_id)
            )

        self.connection.commit()

    def host_by_id(self, host_id: str) -> Optional[OnlineHost]:
        """Return the host with this id as an OnlineHost, or None if it does not exist."""
        self.cursor.execute("SELECT hosts.id, hosts.https_url FROM hosts WHERE hosts.id = %s", (host_id,))
        row = self.cursor.fetchone()
        if row:
            return OnlineHost(row[0], row[1])
        return None

    def host_of_capsul(self, capsul_id: str) -> Optional[OnlineHost]:
        """Return the host that the given VM lives on, or None if the VM is not found."""
        self.cursor.execute(
            "SELECT hosts.id, hosts.https_url from vms JOIN hosts on hosts.id = vms.host where vms.id = %s",
            (capsul_id,)
        )
        row = self.cursor.fetchone()
        if row:
            return OnlineHost(row[0], row[1])
        return None

    def get_payload_json_from_host_operation(self, operation_id: int, host_id: str) -> Optional[str]:
        """Return the operation's payload if that operation is linked to the host, else None."""
        self.cursor.execute(
            """
            SELECT operations.payload FROM operations
            JOIN host_operation ON host_operation.operation = operations.id
            WHERE host_operation.host = %s AND host_operation.operation = %s
            """,
            (host_id, operation_id)
        )
        row = self.cursor.fetchone()
        if row:
            return row[0]
        return None

    def claim_operation(self, operation_id: int, host_id: str) -> bool:
        """Try to mark the operation as 'assigned' to ``host_id``.

        The UPDATE only matches when no ``host_operation`` row for this
        operation already has ``assignment_status = 'assigned'``.

        Returns ``self.cursor.rowcount != 0`` as observed after the execute.
        """
        # have to make a new cursor to set isolation level
        # cursor = self.connection.cursor()
        # self.cursor.execute("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
        # psycopg2.errors.ActiveSqlTransaction: SET TRANSACTION ISOLATION LEVEL must be called before any query

        # NOTE(review): three statements are sent in a single execute(), and
        # rowcount is read afterwards. If rowcount reflects the final COMMIT
        # rather than the UPDATE, this method may report True even when no
        # row was claimed - confirm against the psycopg2 cursor.rowcount docs.
        self.cursor.execute("""
            BEGIN TRANSACTION;
            UPDATE host_operation SET assignment_status = 'assigned'
            WHERE host = %s AND operation = %s AND operation != (
                SELECT COALESCE(
                    (SELECT operation FROM host_operation WHERE operation = %s AND assignment_status = 'assigned'),
                    -1
                ) as already_assigned_operation_id
            );
            COMMIT TRANSACTION;
            """,
            (host_id, operation_id, operation_id)
        )

        to_return = self.cursor.rowcount != 0

        self.connection.commit()

        #cursor.close()

        return to_return

    def set_broadcast_message(self, message):
        """Replace the contents of ``broadcast_message`` with a single new message."""
        self.cursor.execute(
            "DELETE FROM broadcast_message; INSERT INTO broadcast_message (message) VALUES (%s)",
            (message, )
        )
        self.connection.commit()