initdb.py 5.43 KB
#!/usr/bin/env python3

# base
import csv
import argparse
import re
import string
from concurrent.futures import ThreadPoolExecutor

# installed packages
import bcrypt
import sqlalchemy as sa

# this project
from models import Base, Student


# ===========================================================================
#   Parse command line options
def parse_commandline_arguments(argv=None):
    """Parse the command line options.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes:
          - ``csvfile``: list of CSV paths (empty list when none given),
          - ``db``: database file name,
          - ``admin``: whether to insert the admin user,
          - ``add``: list of ``[uid, name]`` pairs, or ``None`` if ``-a`` unused,
          - ``update``: list of uids whose password should be reset,
          - ``pw``: password for new/updated users, or ``None``,
          - ``verbose``: whether to list every user.
    """
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Insert new users into a database. Users can be imported from CSV files in the SIIUE format or defined in the command line. If the database does not exist, a new one is created.')

    # A list default keeps `csvfile` a list whether or not files are given
    # (a '' default only worked because iterating '' yields nothing).
    argparser.add_argument('csvfile',
        nargs='*',
        type=str,
        default=[],
        help='CSV file to import (SIIUE)')

    argparser.add_argument('--db',
        default='students.db',
        type=str,
        help='database file')

    argparser.add_argument('-A', '--admin',
        action='store_true',
        help='insert the admin user')

    # 'append' with nargs=2 yields a list of [uid, name] pairs (or None).
    argparser.add_argument('-a', '--add',
        nargs=2,
        action='append',
        metavar=('uid', 'name'),
        help='add new user')

    argparser.add_argument('-u', '--update',
        nargs='+',
        metavar='uid',
        default=[],
        help='users to update')

    argparser.add_argument('--pw',
        default=None,
        type=str,
        help='set password for new and updated users')

    argparser.add_argument('-V', '--verbose',
        action='store_true',
        help='show all students in database')

    return argparser.parse_args(argv)


# ===========================================================================
# SIIUE names contain cryptic annotations such as "(TE)" and are sometimes written in all caps.
# We strip the annotations and normalize the capitalization so that students don't keep asking what they mean.
def get_students_from_csv(filename):
    """Read the students listed in a SIIUE CSV export.

    The file is decoded as ISO-8859-1 and parsed as ``;``-separated values
    with a header row. The ``N.º`` column gives the uid and ``Nome`` the
    name. Parenthesized annotations such as "(TE)" are removed from each
    name, and the result is title-cased with ``string.capwords`` (which also
    collapses runs of whitespace into single spaces).

    Args:
        filename: path of the CSV file.

    Returns:
        list of ``{'uid': str, 'name': str}`` dicts, one per data row.
        If the file cannot be opened, an error is printed and an empty
        list is returned.
    """
    try:
        csvfile = open(filename, encoding='iso-8859-1', newline='')  # newline='' as required by the csv module
    except OSError:
        print(f'!!! Error. File "{filename}" not found !!!')
        return []

    # `with` guarantees the file is closed once all rows have been read.
    with csvfile:
        csvreader = csv.DictReader(csvfile, delimiter=';', quotechar='"', skipinitialspace=True)
        return [{
            'uid': s['N.º'],
            # Non-greedy: "A (X) B (Y)" must keep "B"; a greedy `\(.*\)`
            # would delete everything from the first "(" to the last ")".
            'name': string.capwords(re.sub(r'\([^)]*\)', '', s['Nome']).strip()),
            } for s in csvreader]


# ===========================================================================
# replace password by hash for a single student
def hashpw(student, pw=None):
    """Replace a student's plaintext password with its bcrypt hash, in place.

    The plaintext is chosen by precedence: the ``pw`` argument if truthy,
    otherwise the student's existing ``'pw'`` entry if truthy, otherwise
    the student's ``'uid'``. The hash (bytes) is stored under
    ``student['pw']``. A progress dot is printed on every call.
    """
    print('.', end='', flush=True)
    secret = pw or student.get('pw', None) or student['uid']
    salt = bcrypt.gensalt()
    student['pw'] = bcrypt.hashpw(secret.encode('utf-8'), salt)


# ===========================================================================
def insert_students_into_db(session, students):
    """Insert all given students in a single transaction.

    Each student dict must provide ``'uid'``, ``'name'`` and ``'pw'``
    (the password hash). If the commit raises an ``IntegrityError``
    (reported as users already being in the database), an error is printed
    and the transaction is rolled back, so none of the students are inserted.
    """
    try:
        new_rows = [
            Student(id=row['uid'], name=row['name'], password=row['pw'])
            for row in students
        ]
        session.add_all(new_rows)
        session.commit()
    except sa.exc.IntegrityError:
        print('!!! Integrity error. User(s) already in database. None inserted !!!\n')
        session.rollback()


# ===========================================================================
def show_students_in_database(session, verbose=False):
    """Print the users registered in the database, ordered by id.

    Args:
        session: an open SQLAlchemy session.
        verbose: if true, list every user. Otherwise list only the first
            two users and the last one, with a "|   |" marker when at least
            one user in between is omitted.

    Errors raised by the query propagate to the caller.
    """
    users = session.query(Student).order_by(Student.id).all()
    n = len(users)

    def show(user):
        print(f'{user.id:>12}   {user.name}')

    print('Registered users:')
    if not users:
        print('  -- none --')
    elif verbose:
        for u in users:
            show(u)
    else:
        show(users[0])
        if n > 1:
            show(users[1])
        if n > 3:
            # users between the second and the last are not shown
            print('           |   |')
        if n > 2:
            show(users[-1])
    print(f'Total: {n}.')


# ===========================================================================
def main():
    """Collect users from CSV files and the command line, hash their
    passwords, insert them into the database, reset requested passwords,
    and print a summary of the registered users."""
    args = parse_commandline_arguments()

    # --- make list of students to insert/update
    students = []

    for csvfile in args.csvfile:
        print('Adding users from:', csvfile)
        students.extend(get_students_from_csv(csvfile))

    if args.admin:
        print('Adding user: 0, Admin.')
        students.append({'uid': '0', 'name': 'Admin'})

    # args.add is None when -a/--add was never given
    for uid, name in args.add or []:
        print(f'Adding user: {uid}, {name}.')
        students.append({'uid': uid, 'name': name})

    # --- password hashing
    if students:
        print('Generating password hashes (bcrypt).')
        with ThreadPoolExecutor() as executor:
            # Hash in parallel. Consuming the result iterator re-raises any
            # exception from a worker; an unconsumed map() would silently
            # drop it and leave the student without a 'pw' entry.
            list(executor.map(lambda s: hashpw(s, args.pw), students))

        print()

    # --- database stuff
    print('Using database: ', args.db)
    engine = sa.create_engine(f'sqlite:///{args.db}', echo=False)
    Base.metadata.create_all(engine)  # create schema if needed
    Session = sa.orm.sessionmaker(bind=engine)
    session = Session()

    try:
        if students:
            print(f'Inserting {len(students)}')
            insert_students_into_db(session, students)

        for uid in args.update:
            print(f'Updating password of: {uid}')
            user = session.query(Student).get(uid)
            if user is None:
                # unknown uid: report it instead of crashing on None.password
                print(f'!!! Error. User "{uid}" not found in database !!!')
                continue
            pw = (args.pw or uid).encode('utf-8')
            user.password = bcrypt.hashpw(pw, bcrypt.gensalt())
            session.commit()

        show_students_in_database(session, args.verbose)
    finally:
        # release the session even if an update or the listing fails
        session.close()


if __name__ == '__main__':
    main()