forked from 3wordchant/capsul-flask
cf42ac5e4d
.. using server-side API tokens
517 lines
19 KiB
Python
517 lines
19 KiB
Python
import re
|
|
|
|
# I was never able to get this type hinting to work correctly
|
|
# from psycopg2.extensions import connection as Psycopg2Connection, cursor as Psycopg2Cursor
|
|
import hashlib
|
|
from nanoid import generate
|
|
from flask import current_app
|
|
from typing import List
|
|
|
|
from capsulflask.shared import OnlineHost
|
|
|
|
|
|
class DBModel:
    """Data-access object: every method runs SQL through the supplied cursor.

    Methods that modify data call ``self.connection.commit()`` themselves.
    Query values are passed as bound parameters (``%s`` placeholders), never
    interpolated into the SQL text.
    """

    def __init__(self, connection, cursor):
        """Store the DB-API connection and cursor used by every method.

        Both are presumably psycopg2 objects (the file has a commented-out
        psycopg2 typing import); they are left unannotated.
        """
        self.connection = connection
        self.cursor = cursor

    # ------ PRIVATE HELPERS ---------

    @staticmethod
    def _hash_api_token(token):
        """Return the hex digest under which an API token is stored.

        NOTE(review): MD5 is kept because already-stored hashes depend on it;
        switching algorithms would need a data migration.
        """
        return hashlib.md5(token.encode('utf-8')).hexdigest()

    @staticmethod
    def _require_valid_host_id(host_id):
        """Raise ValueError unless host_id is entirely [a-zA-Z0-9_-]+.

        fullmatch is used because ``$`` in a match() pattern would also accept
        a trailing newline.
        """
        if not re.fullmatch(r"[a-zA-Z0-9_-]+", host_id):
            raise ValueError(f'host_id "{host_id}" must match "^[a-zA-Z0-9_-]+"')

    # ------ LOGIN ---------

    def login(self, email):
        """Create a login token for email, creating the account row if needed.

        Returns a tuple ``(token, ignore_case_matches)``. ``token`` is None
        when more than two tokens were already issued for this email in the
        last 20 minutes. ``ignore_case_matches`` lists other account emails
        that differ only by case; it is only looked up when this exact email
        has never logged in.
        """
        self.cursor.execute("SELECT * FROM accounts WHERE email = %s", (email, ))
        has_exact_match = len(self.cursor.fetchall()) > 0

        self.cursor.execute(
            "SELECT * FROM accounts WHERE email = %s AND ever_logged_in = TRUE",
            (email, )
        )
        ever_logged_in = len(self.cursor.fetchall()) > 0

        ignore_case_matches = []
        if not ever_logged_in:
            self.cursor.execute(
                "SELECT email FROM accounts WHERE lower_case_email = %s AND email != %s",
                (email.lower(), email)
            )
            ignore_case_matches = [row[0] for row in self.cursor.fetchall()]

        if not has_exact_match:
            self.cursor.execute(
                "INSERT INTO accounts (email, lower_case_email) VALUES (%s, %s)",
                (email, email.lower())
            )

        self.cursor.execute(
            "SELECT token FROM login_tokens WHERE email = %s and created > (NOW() - INTERVAL '20 min')",
            (email, )
        )
        if len(self.cursor.fetchall()) > 2:
            # Rate limit hit.
            # NOTE(review): a new account INSERT above is not committed on this path.
            return (None, ignore_case_matches)

        token = generate()
        self.cursor.execute("INSERT INTO login_tokens (email, token) VALUES (%s, %s)", (email, token))
        self.connection.commit()

        return (token, ignore_case_matches)

    def authenticate_token(self, token):
        """Look up the account that owns an API token.

        Returns the matching result row (a one-column row holding the email)
        when exactly one stored hash matches, otherwise None.
        NOTE(review): the row itself is returned, not ``row[0]``; confirm
        callers expect that before changing it.
        """
        self.cursor.execute(
            "SELECT email FROM api_tokens WHERE token = %s",
            (self._hash_api_token(token), )
        )
        result = self.cursor.fetchall()
        if len(result) == 1:
            return result[0]
        return None

    def consume_token(self, token):
        """Redeem a login token issued within the last 20 minutes.

        On success deletes all login tokens for that email, marks the account
        as having logged in, commits, and returns the email. Returns None when
        the token is unknown or expired.
        """
        self.cursor.execute(
            "SELECT email FROM login_tokens WHERE token = %s and created > (NOW() - INTERVAL '20 min')",
            (token, )
        )
        row = self.cursor.fetchone()
        if not row:
            return None

        email = row[0]
        self.cursor.execute("DELETE FROM login_tokens WHERE email = %s", (email, ))
        self.cursor.execute("UPDATE accounts SET ever_logged_in = TRUE WHERE email = %s", (email, ))
        self.connection.commit()
        return email

    # ------ VM & ACCOUNT MANAGEMENT ---------

    def non_deleted_vms_by_host_and_network(self, host_id):
        """Group non-deleted VMs as ``{host: {network_name: [vm, ...]}}``.

        Each vm is ``dict(id, public_ipv4, public_ipv6)``. When host_id is not
        None only that host is queried.

        Raises:
            ValueError: host_id contains characters outside [a-zA-Z0-9_-].
        """
        query = "SELECT id, host, network_name, public_ipv4, public_ipv6 FROM vms WHERE deleted IS NULL"
        if host_id is None:
            self.cursor.execute(query)
        else:
            self._require_valid_host_id(host_id)
            # The parameter tuple needs its trailing comma; a bare string here
            # triggers "not all arguments converted during string formatting".
            self.cursor.execute(query + " AND host = %s", (host_id, ))

        hosts = dict()
        for vm_id, vm_host, network_name, public_ipv4, public_ipv6 in self.cursor.fetchall():
            networks = hosts.setdefault(vm_host, dict())
            networks.setdefault(network_name, []).append(
                dict(id=vm_id, public_ipv4=public_ipv4, public_ipv6=public_ipv6)
            )

        return hosts

    def all_non_deleted_vm_ids(self):
        """Return the ids of all VMs whose ``deleted`` column is NULL."""
        self.cursor.execute("SELECT id FROM vms WHERE deleted IS NULL")
        return [row[0] for row in self.cursor.fetchall()]

    def operating_systems_dict(self):
        """Return non-deprecated OS images keyed by id."""
        self.cursor.execute("SELECT id, template_image_file_name, description FROM os_images WHERE deprecated = FALSE")
        return {
            row[0]: dict(template_image_file_name=row[1], description=row[2])
            for row in self.cursor.fetchall()
        }

    def vm_sizes_dict(self):
        """Return all VM sizes keyed by id."""
        self.cursor.execute("SELECT id, dollars_per_month, vcpus, memory_mb, bandwidth_gb_per_month FROM vm_sizes")
        return {
            row[0]: dict(dollars_per_month=row[1], vcpus=row[2], memory_mb=row[3], bandwidth_gb_per_month=row[4])
            for row in self.cursor.fetchall()
        }

    def list_ssh_public_keys_for_account(self, email):
        """Return the account's SSH public keys as dicts (name, content, created)."""
        self.cursor.execute("SELECT name, content, created FROM ssh_public_keys WHERE email = %s", (email, ))
        return [
            dict(name=row[0], content=row[1], created=row[2])
            for row in self.cursor.fetchall()
        ]

    def ssh_public_key_name_exists(self, email, name):
        """Return True if the account already has an SSH public key called name."""
        self.cursor.execute("SELECT name FROM ssh_public_keys where email = %s AND name = %s", (email, name))
        return len(self.cursor.fetchall()) > 0

    def create_ssh_public_key(self, email, name, content):
        """Insert an SSH public key for the account and commit."""
        self.cursor.execute("""
            INSERT INTO ssh_public_keys (email, name, content)
            VALUES (%s, %s, %s)
            """,
            (email, name, content)
        )
        self.connection.commit()

    def delete_ssh_public_key(self, email, name):
        """Delete the account's SSH public key called name and commit."""
        self.cursor.execute("DELETE FROM ssh_public_keys where email = %s AND name = %s", (email, name))
        self.connection.commit()

    def list_api_tokens(self, email):
        """Return the account's API tokens as dicts (id, token, name, created).

        ``token`` is the stored value, i.e. the hash written by
        generate_api_token, not the plaintext token.
        """
        self.cursor.execute(
            "SELECT id, token, name, created FROM api_tokens WHERE email = %s",
            (email, )
        )
        return [
            dict(id=row[0], token=row[1], name=row[2], created=row[3])
            for row in self.cursor.fetchall()
        ]

    def generate_api_token(self, email, name):
        """Create an API token, store only its hash, commit, and return the plaintext."""
        token = generate()
        self.cursor.execute(
            "INSERT INTO api_tokens (email, name, token) VALUES (%s, %s, %s)",
            (email, name, self._hash_api_token(token))
        )
        self.connection.commit()
        return token

    def delete_api_token(self, email, id_):
        """Delete the account's API token with the given id and commit."""
        self.cursor.execute("DELETE FROM api_tokens where email = %s AND id = %s", (email, id_))
        self.connection.commit()

    def list_vms_for_account(self, email):
        """Return every VM row of the account (deleted ones included) with its monthly price."""
        self.cursor.execute("""
            SELECT vms.id, vms.public_ipv4, vms.public_ipv6, vms.size, vms.os, vms.created, vms.deleted, vm_sizes.dollars_per_month
            FROM vms JOIN vm_sizes on vms.size = vm_sizes.id
            WHERE vms.email = %s""",
            (email, )
        )
        return [
            dict(
                id=row[0], ipv4=row[1], ipv6=row[2], size=row[3], os=row[4],
                created=row[5], deleted=row[6], dollars_per_month=row[7]
            )
            for row in self.cursor.fetchall()
        ]

    def update_vm_ip(self, email, id, ipv4):
        """Set the public IPv4 address of the account's VM and commit."""
        self.cursor.execute("UPDATE vms SET public_ipv4 = %s WHERE email = %s AND id = %s", (ipv4, email, id))
        self.connection.commit()

    def update_vm_ssh_host_keys(self, email, id, ssh_host_keys):
        """Insert one row per host key, then commit once.

        Each element of ssh_host_keys must provide the keys 'key_type',
        'content' and 'sha256'.
        """
        for key in ssh_host_keys:
            self.cursor.execute("""
                INSERT INTO vm_ssh_host_key (email, vm_id, key_type, content, sha256)
                VALUES (%s, %s, %s, %s, %s)
                """,
                (email, id, key['key_type'], key['content'], key['sha256'])
            )
        self.connection.commit()

    def create_vm(self, email, id, size, os, host, network_name, public_ipv4, ssh_authorized_keys):
        """Insert the VM row plus one authorized-key row per key name, then commit once."""
        self.cursor.execute("""
            INSERT INTO vms (email, id, size, os, host, network_name, public_ipv4)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
            """,
            (email, id, size, os, host, network_name, public_ipv4)
        )

        for ssh_authorized_key in ssh_authorized_keys:
            self.cursor.execute("""
                INSERT INTO vm_ssh_authorized_key (email, vm_id, ssh_public_key_name)
                VALUES (%s, %s, %s)
                """,
                (email, id, ssh_authorized_key)
            )
        self.connection.commit()

    def delete_vm(self, email, id):
        """Soft-delete the account's VM by stamping ``deleted`` with now(), then commit."""
        self.cursor.execute("UPDATE vms SET deleted = now() WHERE email = %s AND id = %s", (email, id))
        self.connection.commit()

    def get_vm_detail(self, email, id):
        """Return a dict describing the account's VM, or None if there is no such VM.

        The dict carries the VM columns, its size details, the OS description,
        ``ssh_authorized_keys`` (list of key names) and ``ssh_host_keys``
        (list of dicts with key_type, content, sha256).
        """
        self.cursor.execute("""
            SELECT vms.id, vms.public_ipv4, vms.public_ipv6, os_images.description, vms.created, vms.deleted,
                   vm_sizes.id, vm_sizes.dollars_per_month, vm_sizes.vcpus, vm_sizes.memory_mb, vm_sizes.bandwidth_gb_per_month
            FROM vms
            JOIN os_images on vms.os = os_images.id
            JOIN vm_sizes on vms.size = vm_sizes.id
            WHERE vms.email = %s AND vms.id = %s""",
            (email, id)
        )
        row = self.cursor.fetchone()
        if not row:
            return None

        vm = dict(
            id=row[0], ipv4=row[1], ipv6=row[2], os_description=row[3], created=row[4], deleted=row[5],
            size=row[6], dollars_per_month=row[7], vcpus=row[8], memory_mb=row[9], bandwidth_gb_per_month=row[10]
        )

        self.cursor.execute("""
            SELECT ssh_public_key_name FROM vm_ssh_authorized_key
            WHERE vm_ssh_authorized_key.email = %s AND vm_ssh_authorized_key.vm_id = %s""",
            (email, id)
        )
        vm["ssh_authorized_keys"] = [key_row[0] for key_row in self.cursor.fetchall()]

        self.cursor.execute("""
            SELECT key_type, content, sha256 FROM vm_ssh_host_key
            WHERE vm_ssh_host_key.email = %s AND vm_ssh_host_key.vm_id = %s""",
            (email, id)
        )
        vm["ssh_host_keys"] = [
            dict(key_type=key_row[0], content=key_row[1], sha256=key_row[2])
            for key_row in self.cursor.fetchall()
        ]

        return vm

    # ------ PAYMENTS & ACCOUNT BALANCE ---------

    def list_payments_for_account(self, email):
        """Return the account's payments as dicts (id, dollars, invalidated, created)."""
        self.cursor.execute("""
            SELECT id, dollars, invalidated, created
            FROM payments WHERE payments.email = %s""",
            (email, )
        )
        return [
            dict(id=row[0], dollars=row[1], invalidated=row[2], created=row[3])
            for row in self.cursor.fetchall()
        ]

    def create_payment_session(self, payment_type, id, email, dollars):
        """Record a pending payment session and commit."""
        self.cursor.execute("""
            INSERT INTO payment_sessions (id, type, email, dollars)
            VALUES (%s, %s, %s, %s)
            """,
            (id, payment_type, email, dollars)
        )
        self.connection.commit()

    def list_payment_sessions_for_account(self, email):
        """Return the account's payment sessions as dicts (id, type, dollars, created)."""
        self.cursor.execute("""
            SELECT id, type, dollars, created
            FROM payment_sessions WHERE email = %s""",
            (email, )
        )
        return [
            dict(id=row[0], type=row[1], dollars=row[2], created=row[3])
            for row in self.cursor.fetchall()
        ]

    def payment_session_redirect(self, email, id):
        """Mark a payment session as redirected and report its previous state.

        Returns the ``redirected`` value read before the update, or None when
        the account has no session with that id.
        """
        self.cursor.execute(
            "SELECT redirected FROM payment_sessions WHERE email = %s AND id = %s",
            (email, id)
        )
        row = self.cursor.fetchone()
        if not row:
            return None

        self.cursor.execute(
            "UPDATE payment_sessions SET redirected = TRUE WHERE email = %s AND id = %s",
            (email, id)
        )
        self.connection.commit()
        return row[0]

    def consume_payment_session(self, payment_type, id, dollars):
        """Turn a completed payment session into a payment.

        Deletes the session, inserts a payment for the recorded amount and,
        for "btcpay", registers an unresolved invoice pointing at the new
        payment. A mismatch between the recorded and reported dollar amounts
        is only logged; the recorded amount is what gets credited.

        Returns the account email, or None if no such session exists.
        """
        self.cursor.execute(
            "SELECT email, dollars FROM payment_sessions WHERE id = %s AND type = %s",
            (id, payment_type)
        )
        row = self.cursor.fetchone()
        if not row:
            return None

        email, recorded_dollars = row[0], row[1]
        if int(recorded_dollars) != int(dollars):
            current_app.logger.warning(f"""
          {payment_type} gave us a completed payment session with a different dollar amount than what we had recorded!!
            id: {id}
            account: {email}
            Recorded amount: {int(recorded_dollars)}
            {payment_type} sent: {int(dollars)}
        """)
            # not sure what to do here. For now just log and do nothing.

        self.cursor.execute("DELETE FROM payment_sessions WHERE id = %s AND type = %s", (id, payment_type))
        self.cursor.execute(
            "INSERT INTO payments (email, dollars) VALUES (%s, %s) RETURNING id",
            (email, recorded_dollars)
        )

        if payment_type == "btcpay":
            payment_id = self.cursor.fetchone()[0]
            self.cursor.execute(
                "INSERT INTO unresolved_btcpay_invoices (id, email, payment_id) VALUES (%s, %s, %s)",
                (id, email, payment_id)
            )

        self.connection.commit()
        return email

    def delete_payment_session(self, payment_type, id):
        """Delete a payment session without creating a payment, then commit."""
        self.cursor.execute("DELETE FROM payment_sessions WHERE id = %s AND type = %s", (id, payment_type))
        self.connection.commit()

    def btcpay_invoice_resolved(self, id, completed):
        """Drop an invoice from the unresolved list; invalidate its payment if not completed.

        Does nothing when the invoice id is not in unresolved_btcpay_invoices.
        """
        self.cursor.execute("SELECT email, payment_id FROM unresolved_btcpay_invoices WHERE id = %s ", (id,))
        row = self.cursor.fetchone()
        if row:
            self.cursor.execute("DELETE FROM unresolved_btcpay_invoices WHERE id = %s", (id,))
            if not completed:
                # The two conditions must be joined with AND; without it the
                # statement is not valid SQL.
                self.cursor.execute(
                    "UPDATE payments SET invalidated = TRUE WHERE email = %s AND id = %s",
                    (row[0], row[1])
                )

            self.connection.commit()

    def get_unresolved_btcpay_invoices(self):
        """Return unresolved BTCPay invoices joined with their payment (id, created, dollars, email)."""
        self.cursor.execute("""
            SELECT unresolved_btcpay_invoices.id, payments.created, payments.dollars, unresolved_btcpay_invoices.email
            FROM unresolved_btcpay_invoices JOIN payments on payment_id = payments.id
        """)
        return [
            dict(id=row[0], created=row[1], dollars=row[2], email=row[3])
            for row in self.cursor.fetchall()
        ]

    def get_account_balance_warning(self, email):
        """Return the account's ``account_balance_warning`` value.

        Assumes the account exists: with no matching row, fetchone() yields
        nothing to index and the lookup fails.
        """
        self.cursor.execute("SELECT account_balance_warning FROM accounts WHERE email = %s", (email,))
        return self.cursor.fetchone()[0]

    def set_account_balance_warning(self, email, account_balance_warning):
        """Store the account's ``account_balance_warning`` value and commit."""
        self.cursor.execute(
            "UPDATE accounts SET account_balance_warning = %s WHERE email = %s",
            (account_balance_warning, email)
        )
        self.connection.commit()

    def all_accounts(self):
        """Return every account that has ever logged in (email, account_balance_warning)."""
        self.cursor.execute("SELECT email, account_balance_warning FROM accounts WHERE ever_logged_in = TRUE ")
        return [
            dict(email=row[0], account_balance_warning=row[1])
            for row in self.cursor.fetchall()
        ]

    # ------ HOSTS ---------

    def list_hosts_with_networks(self, host_id: str):
        """Return hosts keyed by id, each with last_health_check and its networks.

        When host_id is not None only that host is queried.

        Raises:
            ValueError: host_id contains characters outside [a-zA-Z0-9_-].
        """
        query = """
            SELECT hosts.id, hosts.last_health_check, host_network.network_name,
            host_network.public_ipv4_cidr_block, host_network.public_ipv4_first_usable_ip, host_network.public_ipv4_last_usable_ip
            FROM hosts
            JOIN host_network ON host_network.host = hosts.id
        """
        if host_id is None:
            self.cursor.execute(query)
        else:
            self._require_valid_host_id(host_id)
            # The parameter tuple needs its trailing comma; a bare string here
            # triggers "not all arguments converted during string formatting".
            self.cursor.execute(query + " WHERE hosts.id = %s", (host_id, ))

        hosts = dict()
        for row in self.cursor.fetchall():
            host = hosts.setdefault(row[0], dict(last_health_check=row[1], networks=[]))
            host["networks"].append(dict(
                network_name=row[2],
                public_ipv4_cidr_block=row[3],
                public_ipv4_first_usable_ip=row[4],
                public_ipv4_last_usable_ip=row[5]
            ))

        return hosts

    def authorized_for_host(self, id, token) -> bool:
        """Return True if a host row exists with this id and token."""
        self.cursor.execute("SELECT id FROM hosts WHERE id = %s AND token = %s", (id, token))
        return self.cursor.fetchone() is not None

    def host_heartbeat(self, id) -> None:
        """Stamp the host's last_health_check with NOW() and commit."""
        self.cursor.execute("UPDATE hosts SET last_health_check = NOW() WHERE id = %s", (id,))
        self.connection.commit()

    def get_all_hosts(self) -> "List[OnlineHost]":
        """Return every host as an OnlineHost(id, url)."""
        self.cursor.execute("SELECT id, https_url FROM hosts")
        return [OnlineHost(id=row[0], url=row[1]) for row in self.cursor.fetchall()]

    def get_online_hosts(self) -> "List[OnlineHost]":
        """Return hosts whose last health check is less than 20 seconds old."""
        self.cursor.execute("SELECT id, https_url FROM hosts WHERE last_health_check > NOW() - INTERVAL '20 seconds'")
        return [OnlineHost(id=row[0], url=row[1]) for row in self.cursor.fetchall()]

    def list_all_operations(self):
        """Return operations keyed by id, each with the per-host rows in ``hosts``."""
        # Nine columns are selected; the loop below reads row[0] through row[8].
        self.cursor.execute("""
            SELECT operations.id, operations.email, operations.created, operations.payload,
            host_operation.host, host_operation.assignment_status, host_operation.assigned,
            host_operation.completed, host_operation.results FROM operations
            JOIN host_operation ON host_operation.operation = operations.id
        """)

        operations = dict()
        for row in self.cursor.fetchall():
            operation = operations.setdefault(
                row[0],
                dict(email=row[1], created=row[2], payload=row[3], hosts=[])
            )
            operation["hosts"].append(dict(
                host=row[4], assignment_status=row[5], assigned=row[6],
                completed=row[7], results=row[8],
            ))

        return operations

    def create_operation(self, online_hosts: "List[OnlineHost]", email: str, payload: str) -> int:
        """Insert an operation plus one host_operation row per host; return the new id."""
        self.cursor.execute(
            "INSERT INTO operations (email, payload) VALUES (%s, %s) RETURNING id",
            (email, payload)
        )
        operation_id = self.cursor.fetchone()[0]

        for host in online_hosts:
            self.cursor.execute(
                "INSERT INTO host_operation (host, operation) VALUES (%s, %s)",
                (host.id, operation_id)
            )

        self.connection.commit()
        return operation_id

    def update_operation(self, operation_id: int, payload: str):
        """Replace the payload of an operation and commit."""
        self.cursor.execute(
            "UPDATE operations SET payload = %s WHERE id = %s",
            (payload, operation_id)
        )
        self.connection.commit()

    def update_host_operation(self, host_id: str, operation_id: int, assignment_status: str, result: str):
        """Record assignment status and/or results for one host's share of an operation.

        Whichever of assignment_status / result is truthy is written, together
        with its timestamp column (assigned / completed). When both are falsy
        no UPDATE is issued; the commit still runs.
        """
        if assignment_status and not result:
            self.cursor.execute(
                "UPDATE host_operation SET assignment_status = %s, assigned = NOW() WHERE host = %s AND operation = %s",
                (assignment_status, host_id, operation_id)
            )
        elif not assignment_status and result:
            self.cursor.execute(
                "UPDATE host_operation SET results = %s, completed = NOW() WHERE host = %s AND operation = %s",
                (result, host_id, operation_id)
            )
        elif assignment_status and result:
            self.cursor.execute(
                "UPDATE host_operation SET assignment_status = %s, assigned = NOW(), results = %s, completed = NOW() WHERE host = %s AND operation = %s",
                (assignment_status, result, host_id, operation_id)
            )
        self.connection.commit()

    def host_of_capsul(self, capsul_id: str) -> "OnlineHost":
        """Return the OnlineHost running the given VM, or None if the VM is unknown."""
        self.cursor.execute(
            "SELECT hosts.id, hosts.https_url from vms JOIN hosts on hosts.id = vms.host where vms.id = %s",
            (capsul_id,)
        )
        row = self.cursor.fetchone()
        if row:
            return OnlineHost(row[0], row[1])
        return None

    def get_payload_json_from_host_operation(self, operation_id: int, host_id: str) -> str:
        """Return the operation payload if it is linked to host_id, else None."""
        self.cursor.execute(
            """
            SELECT operations.payload FROM operations
            JOIN host_operation ON host_operation.operation = operations.id
            WHERE host_operation.host = %s AND host_operation.operation = %s
            """,
            (host_id, operation_id)
        )
        row = self.cursor.fetchone()
        if row:
            return row[0]
        return None

    def claim_operation(self, operation_id: int, host_id: str) -> bool:
        """Try to mark the operation as assigned to host_id.

        The UPDATE only touches the row when no host_operation row for this
        operation already has assignment_status = 'assigned'. Returns True
        unless the cursor reports a rowcount of 0 afterwards.
        """
        # SET TRANSACTION ISOLATION LEVEL SERIALIZABLE was tried here but
        # psycopg2 rejects it once any query has run on the connection
        # (ActiveSqlTransaction), so the claim is done as one explicit
        # BEGIN ... COMMIT batch instead.
        # NOTE(review): rowcount is read after a multi-statement execute that
        # ends in COMMIT; confirm the driver reports the UPDATE's count here.
        self.cursor.execute("""
            BEGIN TRANSACTION;
            UPDATE host_operation SET assignment_status = 'assigned'
            WHERE host = %s AND operation = %s AND operation != (
                SELECT COALESCE(
                    (SELECT operation FROM host_operation WHERE operation = %s AND assignment_status = 'assigned'),
                    -1
                ) as already_assigned_operation_id
            );
            COMMIT TRANSACTION;
        """, (host_id, operation_id, operation_id))

        to_return = self.cursor.rowcount != 0

        self.connection.commit()

        return to_return
|