diff --git a/compliance-monitor/monitor.py b/compliance-monitor/monitor.py index 0d9ee9823..9d759d2b0 100755 --- a/compliance-monitor/monitor.py +++ b/compliance-monitor/monitor.py @@ -42,6 +42,7 @@ db_get_keys, db_insert_report, db_get_recent_results2, db_patch_approval2, db_get_report, db_ensure_schema, db_get_apikeys, db_update_apikey, db_filter_apikeys, db_clear_delegates, db_find_subjects, db_insert_result2, db_get_relevant_results2, db_add_delegate, db_get_group, + db_filter_accounts, ) @@ -225,10 +226,12 @@ def import_bootstrap(bootstrap_path, conn): if not accounts and not subjects: return with conn.cursor() as cur: + accountids = [] for account in accounts: roles = sum(ROLES[r] for r in account.get('roles', ())) acc_record = {'subject': account['subject'], 'roles': roles, 'group': account.get('group')} accountid = db_update_account(cur, acc_record) + accountids.append(accountid) db_clear_delegates(cur, accountid) for delegate in account.get('delegates', ()): db_add_delegate(cur, accountid, delegate) @@ -236,6 +239,7 @@ def import_bootstrap(bootstrap_path, conn): db_filter_apikeys(cur, accountid, lambda keyid, *_: keyid in keyids) keyids = set(db_update_publickey(cur, accountid, key) for key in account.get("keys", ())) db_filter_publickeys(cur, accountid, lambda keyid, *_: keyid in keyids) + db_filter_accounts(cur, lambda accountid, *_: accountid in accountids) conn.commit() diff --git a/compliance-monitor/sql.py b/compliance-monitor/sql.py index 2d39635a6..1c8f9b07c 100644 --- a/compliance-monitor/sql.py +++ b/compliance-monitor/sql.py @@ -194,7 +194,7 @@ def db_upgrade_schema(conn: connection, cur: cursor): # that way just in case we want to use another database at some point while True: current = db_get_schema_version(cur) - if current == SCHEMA_VERSIONS[-1]: + if current >= SCHEMA_VERSIONS[-1]: # bail if version is too new (but hope it's compatible) break if current is None: # this is an empty db, but it also used to be the case with v1 @@ -252,6 +252,14 @@ def db_update_account(cur: cursor, record: dict): return accountid +def db_filter_accounts(cur: cursor, predicate: callable): + cur.execute('SELECT accountid FROM account;') + removeids = [row[0] for row in cur.fetchall() if not predicate(*row)] + while removeids: + cur.execute('DELETE FROM account WHERE accountid IN %s', (tuple(removeids[:10]), )) + del removeids[:10] + + def db_clear_delegates(cur: cursor, accountid): cur.execute('''DELETE FROM delegation WHERE accountid = %s;''', (accountid, ))