[HOME]

Path : /opt/alt/python38/lib/python3.8/site-packages/peewee_migrate/
Upload :
Current File : //opt/alt/python38/lib/python3.8/site-packages/peewee_migrate/auto.py

from collections import Hashable, OrderedDict

import peewee as pw
from playhouse.reflection import Column as VanilaColumn


# One indentation level (four spaces) used when rendering generated code.
INDENT = '    '
# Line break followed by one indentation level; used as a joiner/prefix so
# that each generated item starts on its own indented line.
NEWLINE = '\n' + INDENT


def fk_to_params(field):
    """Collect the referential-action options set on a foreign key field.

    :param field: object exposing ``on_delete`` and ``on_update`` attributes.
    :returns: dict containing ``on_delete`` and/or ``on_update``, limited to
        the attributes whose value is not ``None``.
    """
    options = (
        ('on_delete', field.on_delete),
        ('on_update', field.on_update),
    )
    return {key: value for key, value in options if value is not None}


# Maps a peewee field class to a callable that takes a field instance and
# returns a dict of the class-specific constructor keyword arguments.
# Lookups in this file use the exact class (``type(field)``), so a subclass
# of one of these field classes does not match its parent's entry.
FIELD_TO_PARAMS = {
    pw.CharField: lambda f: {'max_length': f.max_length},
    pw.DecimalField: lambda f: {
        'max_digits': f.max_digits, 'decimal_places': f.decimal_places,
        'auto_round': f.auto_round, 'rounding': f.rounding},
    pw.ForeignKeyField: fk_to_params,
}


class Column(VanilaColumn):
    """Column description built from a peewee field instance.

    The constructor copies the attributes it needs straight from the field
    object; the parent class initializer is not invoked.
    """

    def __init__(self, field, migrator=None):  # noqa
        """Copy the attributes used for code generation from ``field``.

        :param field: peewee field instance to describe.
        :param migrator: optional object exposing an ``orm`` container. When a
            foreign key targets a model whose ``_meta.name`` is found in it,
            the target is rendered as a ``migrator.orm[...]`` lookup instead
            of the bare class name.
        """
        field_class = type(field)

        self.name = field.name
        self.field_class = field_class
        self.nullable = field.null
        self.primary_key = field.primary_key
        self.db_column = field.db_column
        self.index = field.index
        self.unique = field.unique

        # Extra keyword arguments: a non-callable default first, then whatever
        # the class-specific extractor reports for this exact field class.
        extra = {}
        default = field.default
        if not (default is None or callable(default)):
            extra['default'] = default
        extractor = FIELD_TO_PARAMS.get(field_class)
        if extractor is not None:
            extra.update(extractor(field))
        self.params = extra

        self.rel_model = None
        self.related_name = None
        self.to_field = None

        if not isinstance(field, pw.ForeignKeyField):
            return

        # Foreign keys additionally record the target field and model.
        self.to_field = field.to_field.name
        target = field.rel_model
        if migrator and target._meta.name in migrator.orm:
            self.rel_model = "migrator.orm['%s']" % target._meta.name
        else:
            self.rel_model = target.__name__

    def get_field_parameters(self):
        """Return the parent's parameters extended with ``self.params``.

        Values taken from ``self.params`` are stored as their ``repr`` so they
        can be written directly into generated source code.
        """
        merged = super(Column, self).get_field_parameters()
        for key, value in self.params.items():
            merged[key] = repr(value)
        return merged

    def get_field(self, space=' '):
        """Render the column as a ``name = pw.FieldClass(...)`` definition.

        :param space: text placed on both sides of the ``=`` sign.
        :returns: the definition string with keyword arguments sorted by name.
        """
        pairs = sorted(self.get_field_parameters().items())
        arguments = ', '.join('%s=%s' % pair for pair in pairs)
        template = '{name}{space}={space}pw.{classname}({params})'
        return template.format(
            name=self.name,
            space=space,
            classname=self.field_class.__name__,
            params=arguments,
        )


def diff_one(model1, model2, **kwargs):
    """Find difference between Peewee models.

    Compares the ``_meta.fields`` mappings of both models and returns the
    migration statements needed to turn ``model2`` into ``model1``.

    Field names are visited in the iteration order of the field mappings
    rather than through ``set`` arithmetic. Set iteration order for strings
    changes between interpreter runs (hash randomization), which made the
    order of the generated statements unstable.

    :param model1: model describing the desired state.
    :param model2: model describing the existing state.
    :param kwargs: forwarded to ``create_fields`` and ``change_fields``
        (``diff_many`` passes ``migrator`` this way).
    :returns: list of generated code strings, in the order: added fields,
        dropped fields, changed fields, then one entry per nullability change.
    """
    changes = []

    fields1 = model1._meta.fields
    fields2 = model2._meta.fields

    # Add fields: present on model1 only.
    added = [name for name in fields1 if name not in fields2]
    if added:
        fields = [fields1[name] for name in added]
        changes.append(create_fields(model1, *fields, **kwargs))

    # Drop fields: present on model2 only.
    dropped = [name for name in fields2 if name not in fields1]
    if dropped:
        changes.append(drop_fields(model1, *dropped))

    # Change fields: present on both models.
    fields_ = []
    nulls_ = []
    for name, field1 in fields1.items():
        if name not in fields2:
            continue
        diff = compare_fields(field1, fields2[name])
        # Nullability is migrated with its own statement, so it is taken out
        # of the diff before deciding whether the field itself changed.
        null = diff.pop('null', None)
        if diff:
            fields_.append(field1)

        if null is not None:
            nulls_.append((name, null))

    if fields_:
        changes.append(change_fields(model1, *fields_, **kwargs))

    for name, null in nulls_:
        changes.append(change_not_null(model1, name, null))

    return changes


def diff_many(models1, models2, migrator=None, reverse=False):
    """Generate the migration statements between two collections of models.

    :param models1: models describing the desired state.
    :param models2: models describing the existing state.
    :param migrator: forwarded to ``create_model`` and ``diff_one``.
    :param reverse: when true, both topologically sorted lists are walked in
        reverse order.
    :returns: list of code strings: model creations, then model removals,
        then the per-model field changes.
    """
    wanted = pw.sort_models_topologically(models1)
    existing = pw.sort_models_topologically(models2)

    if reverse:
        wanted = reversed(wanted)
        existing = reversed(existing)

    # Index both sides by model name while keeping the sorted order.
    wanted = OrderedDict((model._meta.name, model) for model in wanted)
    existing = OrderedDict((model._meta.name, model) for model in existing)

    # Models that only exist on the desired side are created.
    changes = [
        create_model(model, migrator=migrator)
        for name, model in wanted.items() if name not in existing
    ]

    # Models that only exist on the current side are removed.
    changes.extend(
        remove_model(model)
        for name, model in existing.items() if name not in wanted
    )

    # Models present on both sides are compared field by field.
    for name, model in wanted.items():
        if name in existing:
            changes += diff_one(model, existing[name], migrator=migrator)

    return changes


def model_to_code(Model, **kwargs):
    """Render a model as the source code of a ``pw.Model`` subclass.

    Every field in ``Model._meta.sorted_fields`` is rendered on its own
    indented line, except a ``pw.PrimaryKeyField`` named ``id``.

    :param Model: model class to render.
    :param kwargs: forwarded to ``field_to_code`` for every field.
    :returns: the class definition as a string, ending with a newline.
    """
    template = """class {classname}(pw.Model):
{fields}
"""
    lines = [
        field_to_code(field, **kwargs) for field in Model._meta.sorted_fields
        if not (isinstance(field, pw.PrimaryKeyField) and field.name == 'id')
    ]
    # A class statement needs at least one statement in its body. When every
    # field was filtered out, emit ``pass`` so the generated code still parses.
    fields = INDENT + NEWLINE.join(lines or ['pass'])
    return template.format(
        classname=Model.__name__,
        fields=fields
    )


def create_model(Model, **kwargs):
    """Return the model source prefixed with the ``@migrator.create_model`` decorator.

    :param Model: model class to render.
    :param kwargs: forwarded to ``model_to_code``.
    """
    decorator = '@migrator.create_model\n'
    return decorator + model_to_code(Model, **kwargs)


def remove_model(Model, **kwargs):
    """Return the ``migrator.remove_model`` call for the model's table.

    :param Model: model class; only ``Model._meta.db_table`` is read.
    :param kwargs: accepted and ignored.
    """
    table = Model._meta.db_table
    return "migrator.remove_model('%s')" % table


def create_fields(Model, *fields, **kwargs):
    """Return the ``migrator.add_fields`` call that adds ``fields`` to a table.

    The table name and each field definition are placed on their own
    indented line.

    :param Model: model class; only ``Model._meta.db_table`` is read.
    :param fields: field instances to render (without spaces around ``=``).
    :param kwargs: forwarded to ``field_to_code``.
    """
    rendered = [field_to_code(field, False, **kwargs) for field in fields]
    separator = ',' + NEWLINE
    definitions = NEWLINE + separator.join(rendered)
    return "migrator.add_fields(%s'%s', %s)" % (
        NEWLINE, Model._meta.db_table, definitions)


def drop_fields(Model, *fields, **kwargs):
    """Return the ``migrator.remove_fields`` call for the given field names.

    :param Model: model class; only ``Model._meta.db_table`` is read.
    :param fields: field names, each rendered with ``repr``.
    :param kwargs: accepted and ignored.
    """
    table = Model._meta.db_table
    names = ', '.join(repr(name) for name in fields)
    return "migrator.remove_fields('%s', %s)" % (table, names)


def field_to_code(field, space=True, **kwargs):
    """Render a field as a ``name = pw.FieldClass(...)`` definition.

    :param field: peewee field instance.
    :param space: when truthy, put a single space on both sides of ``=``;
        otherwise put nothing there.
    :param kwargs: forwarded to the ``Column`` constructor.
    """
    if space:
        padding = ' '
    else:
        padding = ''
    return Column(field, **kwargs).get_field(padding)


def compare_fields(field1, field2, **kwargs):
    """Report how ``field1`` differs from ``field2``.

    :returns: ``{'cls': True}`` when the two fields are of different classes.
        Otherwise a dict of the ``field1`` parameters (including ``null``)
        that are absent from, or different in, ``field2``. The comparison is
        one-directional: a parameter that only ``field2`` has is not reported.
    :param kwargs: accepted and ignored.
    """
    if type(field1) != type(field2):  # noqa
        return {'cls': True}

    # Describe both fields the same way: class-specific params plus ``null``.
    described = []
    for field in (field1, field2):
        params = field_to_params(field)
        params['null'] = field.null
        described.append(params)

    wanted, existing = described
    return dict(set(wanted.items()) - set(existing.items()))


def field_to_params(field, **kwargs):
    """Return the comparable parameters of a field.

    The result holds the class-specific parameters registered in
    ``FIELD_TO_PARAMS`` for the exact type of ``field`` plus, when usable,
    the field's ``default``.

    The default is included only if it is not ``None``, not callable and can
    actually be hashed. Hashability is probed with ``hash()`` instead of an
    ``isinstance(..., Hashable)`` check: the ABC check accepts values such as
    a tuple containing a list, which later fail when the parameters are put
    into a ``set``. It also avoids relying on the ``collections.Hashable``
    alias.

    :param field: peewee field instance.
    :param kwargs: accepted and ignored.
    :returns: a new dict of parameter names to values.
    """
    extractor = FIELD_TO_PARAMS.get(type(field))
    params = extractor(field) if extractor is not None else {}

    default = field.default
    if default is not None and not callable(default):
        try:
            hash(default)
        except TypeError:
            # Unhashable default (e.g. a list or dict): left out on purpose,
            # the comparison code builds sets from these parameters.
            pass
        else:
            params['default'] = default
    return params


def change_fields(Model, *fields, **kwargs):
    """Return the ``migrator.change_fields`` call that redefines ``fields``.

    :param Model: model class; only ``Model._meta.db_table`` is read.
    :param fields: field instances to render (without spaces around ``=``).
    :param kwargs: forwarded to ``field_to_code``, the same way
        ``create_fields`` does. This lets a ``migrator`` passed by the caller
        reach the column rendering, so foreign keys in changed fields are
        rendered consistently with those in added fields.
    """
    definitions = [field_to_code(f, False, **kwargs) for f in fields]
    return "migrator.change_fields('%s', %s)" % (
        Model._meta.db_table, (',' + NEWLINE).join(definitions)
    )


def change_not_null(Model, name, null):
    """Return the migrator call that changes a column's NOT NULL constraint.

    :param Model: model class; only ``Model._meta.db_table`` is read.
    :param name: field name, rendered with ``repr``.
    :param null: truthy selects ``drop_not_null``; falsy selects
        ``add_not_null``.
    """
    if null:
        operation = 'drop_not_null'
    else:
        operation = 'add_not_null'
    table = Model._meta.db_table
    return "migrator.%s('%s', %s)" % (operation, table, repr(name))