forked from 3wordchant/capsul-flask
126 lines
4.0 KiB
Python
126 lines
4.0 KiB
Python
import psycopg2
|
|
import re
|
|
import sys
|
|
from urllib.parse import urlparse
|
|
from os import listdir
|
|
from os.path import isfile, join
|
|
from psycopg2 import pool
|
|
from flask import current_app
|
|
from flask import g
|
|
|
|
from capsulflask.db_model import DBModel
|
|
|
|
def _load_schema_migrations(migrations_path):
    """Read the migration scripts in *migrations_path* into a dict.

    Keys are the leading ``<digits>_up`` / ``<digits>_down`` prefix of each
    filename (e.g. ``02_up`` for a file named ``02_up_xyz.sql``); values are
    the file contents decoded as UTF-8. Entries that are not regular files, or
    whose names do not start with that prefix, are reported and skipped rather
    than crashing startup with an AttributeError/IsADirectoryError.
    """
    migrations = {}
    for filename in listdir(migrations_path):
        full_path = join(migrations_path, filename)
        match = re.search(r"^\d+_(up|down)", filename)
        if match is None or not isfile(full_path):
            print("skipping {}: not a schema migration script".format(filename))
            continue
        with open(full_path, 'rb') as file:
            migrations[match.group()] = file.read().decode("utf8")
    return migrations


def init_app(app):
    """Create the DB connection pool, migrate the schema, and register teardown.

    Reads ``DATABASE_URL`` and ``DATABASE_SCHEMA`` from ``app.config`` and
    stores a psycopg2 ``SimpleConnectionPool`` (1..20 connections) in
    ``app.config['PSYCOPG2_CONNECTION_POOL']``.

    Migration scripts are loaded from ``<app.root_path>/schema_migrations``.
    If the ``schemaversion`` table is absent from ``DATABASE_SCHEMA``, script
    ``01_up`` is run first; then ``NN_up`` scripts are applied one at a time
    until the version stored in ``schemaversion`` equals
    ``desiredSchemaVersion``. After each script the stored version is re-read
    and must equal the expected value.

    Any failure (missing script, SQL error, unexpected version, or a stored
    version newer than ``desiredSchemaVersion``) is printed and the process
    terminates via ``sys.exit(1)``.
    """
    databaseUrl = urlparse(app.config['DATABASE_URL'])

    connection_pool = psycopg2.pool.SimpleConnectionPool(
        1,
        20,
        user=databaseUrl.username,
        password=databaseUrl.password,
        host=databaseUrl.hostname,
        port=databaseUrl.port,
        database=databaseUrl.path[1:],
    )
    app.config['PSYCOPG2_CONNECTION_POOL'] = connection_pool

    schemaMigrationsPath = join(app.root_path, 'schema_migrations')
    print("loading schema migration scripts from {}".format(schemaMigrationsPath))
    schemaMigrations = _load_schema_migrations(schemaMigrationsPath)

    actionWasTaken = False
    desiredSchemaVersion = 2

    connection = connection_pool.getconn()
    cursor = connection.cursor()
    # try/finally so the cursor is closed and the connection handed back to the
    # pool on every path, including sys.exit() and unexpected DB errors.
    try:
        # Schema name is passed as a query parameter instead of being
        # str.format()-ed into the SQL text.
        cursor.execute(
            "SELECT table_name, table_schema FROM information_schema.tables WHERE table_schema = %s",
            (app.config['DATABASE_SCHEMA'],),
        )
        hasSchemaVersionTable = any(row[0] == "schemaversion" for row in cursor.fetchall())

        if not hasSchemaVersionTable:
            print("no table named schemaversion found in the {} schema. running migration 01_up".format(app.config['DATABASE_SCHEMA']))
            try:
                cursor.execute(schemaMigrations["01_up"])
                connection.commit()
            except Exception:
                print("unable to create the schemaversion table because: {}".format(my_exec_info_message(sys.exc_info())))
                sys.exit(1)
            actionWasTaken = True

        cursor.execute("SELECT Version FROM schemaversion")
        schemaVersion = cursor.fetchall()[0][0]

        if schemaVersion > desiredSchemaVersion:
            print("schemaVersion ({}) > desiredSchemaVersion ({}). schema downgrades are not supported yet. exiting.".format(
                schemaVersion, desiredSchemaVersion
            ))
            sys.exit(1)

        while schemaVersion < desiredSchemaVersion:
            migrationKey = "%02d_up" % (schemaVersion + 1)
            print("schemaVersion ({}) < desiredSchemaVersion ({}). running migration {}".format(
                schemaVersion, desiredSchemaVersion, migrationKey
            ))

            # Look the script up before the try so a KeyError raised for any
            # other reason is not misreported as a missing file.
            migrationScript = schemaMigrations.get(migrationKey)
            if migrationScript is None:
                print("missing schema migration script: {}_xyz.sql".format(migrationKey))
                sys.exit(1)

            try:
                cursor.execute(migrationScript)
                connection.commit()
            except Exception:
                print("unable to execute the schema migration {} because: {}".format(migrationKey, my_exec_info_message(sys.exc_info())))
                sys.exit(1)
            actionWasTaken = True

            # Each migration script is expected to bump schemaversion by one;
            # verify it did before moving on to the next script.
            schemaVersion += 1
            cursor.execute("SELECT Version FROM schemaversion")
            versionFromDatabase = cursor.fetchall()[0][0]

            if schemaVersion != versionFromDatabase:
                print("incorrect schema version value \"{}\" after running migration {}, expected \"{}\". exiting.".format(
                    versionFromDatabase,
                    migrationKey,
                    schemaVersion
                ))
                sys.exit(1)
    finally:
        cursor.close()
        connection_pool.putconn(connection)

    print("{} current schemaVersion: \"{}\"".format(
        ("schema migration completed." if actionWasTaken else "schema is already up to date. "), schemaVersion
    ))

    app.teardown_appcontext(close_db)
|
|
|
|
|
|
def get_model():
    """Return the DBModel for the current app context, creating it on first use.

    On the first call within a context, a connection is checked out of
    ``current_app.config['PSYCOPG2_CONNECTION_POOL']``, a cursor is opened on
    it, and both are wrapped in a ``DBModel`` cached on flask's ``g`` under the
    ``model`` key. Later calls in the same context reuse that object.

    The cache key must be ``model``: ``close_db`` pops ``g.model`` at teardown
    to return the connection to the pool. Previously the check used ``model``
    but the value was stored as ``db_model``. Every call then checked out a new
    connection, and none were ever returned.
    """
    if 'model' not in g:
        connection = current_app.config['PSYCOPG2_CONNECTION_POOL'].getconn()
        cursor = connection.cursor()
        g.model = DBModel(connection, cursor)

    return g.model
|
|
|
|
|
|
def close_db(e=None):
    """App-context teardown hook that returns the model's connection to the pool.

    Registered with ``app.teardown_appcontext``; *e* is the teardown error
    argument and is not used. If a model was stored on ``g`` under ``model``,
    its cursor is closed and its connection is handed back to
    ``current_app.config['PSYCOPG2_CONNECTION_POOL']``; otherwise nothing
    happens.
    """
    db_model = g.pop("model", None)
    if db_model is None:
        return

    db_model.cursor.close()
    connection_pool = current_app.config['PSYCOPG2_CONNECTION_POOL']
    connection_pool.putconn(db_model.connection)
|
|
|
|
def my_exec_info_message(exec_info):
    """Format a ``sys.exc_info()``-style tuple as ``"<module>.<TypeName>: <value>"``.

    Only the first two items (exception type and exception value) are read.
    """
    exc_type = exec_info[0]
    exc_value = exec_info[1]
    qualified_type = "{}.{}".format(exc_type.__module__, exc_type.__name__)
    return "{}: {}".format(qualified_type, exc_value)