Source code for gino.schema

# noinspection PyProtectedMember
from sqlalchemy import exc, util
from sqlalchemy.sql.base import _bind_or_error
from sqlalchemy.sql.ddl import (
    AddConstraint,
    CreateIndex,
    CreateSequence,
    CreateTable,
    DropConstraint,
    DropIndex,
    DropSequence,
    DropTable,
    SchemaDropper,
    SchemaGenerator,
    SetColumnComment,
    SetTableComment,
    sort_tables_and_constraints,
)
from sqlalchemy.types import SchemaType


class AsyncVisitor:
    """Mixin that gives a schema visitor an awaitable ``traverse_single``."""

    async def traverse_single(self, obj, **kw):
        """Await the first matching ``visit_<name>`` handler for *obj*.

        The visitors are taken from ``self.visitor_iterator`` or, when that
        attribute is absent, from ``self._visitor_iterator``.  For each
        visitor a method named ``"visit_" + obj.__visit_name__`` is looked
        up; the first one found is awaited with ``obj`` and ``**kw`` and its
        result is returned.  ``None`` is returned when no visitor has a
        matching method.
        """
        # Support both SQLAlchemy 1.2 and 1.3
        legacy_visitors = getattr(self, "_visitor_iterator", None)
        visitors = getattr(self, "visitor_iterator", legacy_visitors)
        for visitor in visitors:
            handler = getattr(visitor, "visit_%s" % obj.__visit_name__, None)
            if handler:
                return await handler(obj, **kw)
class AsyncSchemaGenerator(AsyncVisitor, SchemaGenerator):
    """Awaitable variant of ``SchemaGenerator``.

    Existence checks (``dialect.has_table`` / ``dialect.has_sequence``) are
    awaited, DDL statements are issued through ``self.connection.status`` and
    the ``before_create`` / ``after_create`` events are dispatched through
    ``_Async``.
    """

    async def _can_create_table(self, table):
        """Return whether ``CREATE TABLE`` should be issued for *table*.

        The table name and, when present, its effective schema are validated
        by the dialect first.  Without ``checkfirst`` the answer is always
        true; otherwise it is true only if ``dialect.has_table`` reports that
        the table does not exist.
        """
        dialect = self.dialect
        dialect.validate_identifier(table.name)
        schema = self.connection.schema_for_object(table)
        if schema:
            dialect.validate_identifier(schema)
        if not self.checkfirst:
            return True
        exists = await dialect.has_table(
            self.connection, table.name, schema=schema
        )
        return not exists

    async def _can_create_sequence(self, sequence):
        """Return whether ``CREATE SEQUENCE`` should be issued for *sequence*.

        A falsy result is returned when the dialect does not support
        sequences, when the sequence is optional on a dialect where sequences
        are optional, or when ``checkfirst`` is set and
        ``dialect.has_sequence`` reports that the sequence already exists.
        """
        schema = self.connection.schema_for_object(sequence)
        dialect = self.dialect
        supported = dialect.supports_sequences
        if not supported:
            return supported
        if dialect.sequences_optional and sequence.optional:
            return False
        if not self.checkfirst:
            return True
        exists = await dialect.has_sequence(
            self.connection, sequence.name, schema=schema
        )
        return not exists

    async def visit_metadata(self, metadata):
        """Create the sequences, tables and constraints of *metadata*.

        Only tables and column-less sequences that pass the ``_can_create_*``
        checks are processed.  Order of operations: ``before_create`` event,
        sequences, tables in dependency order (with the foreign key
        constraints left over by the sort emitted separately), and finally
        the ``after_create`` event.
        """
        if self.tables is None:
            candidates = list(metadata.tables.values())
        else:
            candidates = self.tables

        creatable = []
        for candidate in candidates:
            if await self._can_create_table(candidate):
                creatable.append(candidate)
        ordered = sort_tables_and_constraints(creatable)

        standalone_sequences = []
        # noinspection PyProtectedMember
        for sequence in metadata._sequences.values():
            if sequence.column is None and await self._can_create_sequence(
                sequence
            ):
                standalone_sequences.append(sequence)

        # Entries whose table is None only carry leftover constraints.
        event_tables = [tbl for (tbl, _) in ordered if tbl is not None]

        await _Async(metadata.dispatch.before_create)(
            metadata,
            self.connection,
            tables=event_tables,
            checkfirst=self.checkfirst,
            _ddl_runner=self,
        )

        for sequence in standalone_sequences:
            await self.traverse_single(sequence, create_ok=True)

        for tbl, constraints in ordered:
            if tbl is None:
                for constraint in constraints:
                    await self.traverse_single(constraint)
                continue
            await self.traverse_single(
                tbl,
                create_ok=True,
                include_foreign_key_constraints=constraints,
                _is_metadata_operation=True,
            )

        await _Async(metadata.dispatch.after_create)(
            metadata,
            self.connection,
            tables=event_tables,
            checkfirst=self.checkfirst,
            _ddl_runner=self,
        )

    async def visit_table(
        self,
        table,
        create_ok=False,
        include_foreign_key_constraints=None,
        _is_metadata_operation=False,
    ):
        """Create *table* together with its defaults, indexes and comments.

        Nothing happens when ``create_ok`` is false and
        ``_can_create_table`` rejects the table.  Otherwise the table's
        ``before_create`` event is dispatched, every non-``None`` column
        default is traversed, ``CreateTable`` is issued, the indexes are
        traversed, table and column comments are set (only if the dialect
        supports comments and does not inline them), and ``after_create`` is
        dispatched.
        """
        if not (create_ok or await self._can_create_table(table)):
            return

        await _Async(table.dispatch.before_create)(
            table,
            self.connection,
            checkfirst=self.checkfirst,
            _ddl_runner=self,
            _is_metadata_operation=_is_metadata_operation,
        )

        for col in table.columns:
            if col.default is not None:
                await self.traverse_single(col.default)

        if not self.dialect.supports_alter:
            # e.g., don't omit any foreign key constraints
            include_foreign_key_constraints = None

        create_stmt = CreateTable(
            table,
            include_foreign_key_constraints=include_foreign_key_constraints,
        )
        await self.connection.status(create_stmt)

        if hasattr(table, "indexes"):
            for idx in table.indexes:
                await self.traverse_single(idx)

        separate_comments = (
            self.dialect.supports_comments
            and not self.dialect.inline_comments
        )
        if separate_comments:
            if table.comment is not None:
                await self.connection.status(SetTableComment(table))
            for col in table.columns:
                if col.comment is not None:
                    await self.connection.status(SetColumnComment(col))

        await _Async(table.dispatch.after_create)(
            table,
            self.connection,
            checkfirst=self.checkfirst,
            _ddl_runner=self,
            _is_metadata_operation=_is_metadata_operation,
        )

    async def visit_foreign_key_constraint(self, constraint):
        """Issue ``AddConstraint`` for *constraint* if ALTER is supported."""
        if self.dialect.supports_alter:
            await self.connection.status(AddConstraint(constraint))

    async def visit_sequence(self, sequence, create_ok=False):
        """Issue ``CreateSequence`` unless the sequence is rejected.

        The ``_can_create_sequence`` check is skipped when ``create_ok`` is
        true.
        """
        if create_ok or await self._can_create_sequence(sequence):
            await self.connection.status(CreateSequence(sequence))

    async def visit_index(self, index):
        """Issue ``CreateIndex`` for *index*."""
        await self.connection.status(CreateIndex(index))
class AsyncSchemaDropper(AsyncVisitor, SchemaDropper):
    """Awaitable variant of ``SchemaDropper``.

    Existence checks (``dialect.has_table`` / ``dialect.has_sequence``) are
    awaited, DDL statements are issued through ``self.connection.status`` and
    the ``before_drop`` / ``after_drop`` events are dispatched through
    ``_Async``.
    """

    async def visit_metadata(self, metadata):
        """Drop the tables, constraints and sequences of *metadata*.

        Only tables and column-less sequences that pass the ``_can_drop_*``
        checks are processed.  Tables are dropped in reverse dependency
        order; foreign key constraints that the sort had to split out are
        dropped separately.  Sequences are dropped after all tables.

        Raises:
            sqlalchemy.exc.CircularDependencyError: if the tables cannot be
                sorted and the dialect supports ALTER.  When ALTER is not
                supported a warning is emitted instead and the tables are
                dropped unsorted.
        """
        if self.tables is not None:
            tables = self.tables
        else:
            tables = list(metadata.tables.values())

        unsorted_tables = []
        for t in tables:
            if await self._can_drop_table(t):
                unsorted_tables.append(t)

        def _constraint_filter(constraint):
            # ``False`` keeps the constraint attached to its table; ``None``
            # leaves the decision to the sort.  Unnamed constraints, and any
            # constraint on a backend without ALTER, cannot be dropped
            # on their own.
            if not self.dialect.supports_alter or constraint.name is None:
                return False
            return None

        # Only the sort itself can raise CircularDependencyError, so it is
        # the only statement guarded by the ``try``.
        try:
            collection = list(
                reversed(
                    sort_tables_and_constraints(
                        unsorted_tables, filter_fn=_constraint_filter
                    )
                )
            )
        except exc.CircularDependencyError as err2:
            cycle_names = ", ".join(sorted([t.fullname for t in err2.cycles]))
            if not self.dialect.supports_alter:
                util.warn(
                    "Can't sort tables for DROP; an "
                    "unresolvable foreign key "
                    "dependency exists between tables: %s, and backend does "
                    "not support ALTER.  To restore at least a partial sort, "
                    "apply use_alter=True to ForeignKey and "
                    "ForeignKeyConstraint "
                    "objects involved in the cycle to mark these as known "
                    "cycles that will be ignored." % cycle_names
                )
                collection = [(t, ()) for t in unsorted_tables]
            else:
                # Native exception chaining keeps ``err2`` as ``__cause__``
                # without relying on SQLAlchemy's ``raise_from_cause`` helper.
                raise exc.CircularDependencyError(
                    err2.args[0],
                    err2.cycles,
                    err2.edges,
                    msg="Can't sort tables for DROP; an unresolvable "
                    "foreign key dependency exists between tables: %s."
                    "  Please ensure that the ForeignKey and "
                    "ForeignKeyConstraint objects involved in the "
                    "cycle have names so that they can be dropped "
                    "using DROP CONSTRAINT." % cycle_names,
                ) from err2

        seq_coll = []
        # noinspection PyProtectedMember
        for s in metadata._sequences.values():
            if s.column is None and await self._can_drop_sequence(s):
                seq_coll.append(s)

        # Entries whose table is None only carry split-out constraints.
        event_collection = [t for (t, fks) in collection if t is not None]

        await _Async(metadata.dispatch.before_drop)(
            metadata,
            self.connection,
            tables=event_collection,
            checkfirst=self.checkfirst,
            _ddl_runner=self,
        )

        for table, fkcs in collection:
            if table is not None:
                await self.traverse_single(
                    table, drop_ok=True, _is_metadata_operation=True
                )
            else:
                for fkc in fkcs:
                    await self.traverse_single(fkc)

        for seq in seq_coll:
            await self.traverse_single(seq, drop_ok=True)

        await _Async(metadata.dispatch.after_drop)(
            metadata,
            self.connection,
            tables=event_collection,
            checkfirst=self.checkfirst,
            _ddl_runner=self,
        )

    async def _can_drop_table(self, table):
        """Return whether ``DROP TABLE`` should be issued for *table*.

        The table name and, when present, its effective schema are validated
        by the dialect first.  Without ``checkfirst`` the answer is always
        true; otherwise it is whatever ``dialect.has_table`` reports.
        """
        self.dialect.validate_identifier(table.name)
        effective_schema = self.connection.schema_for_object(table)
        if effective_schema:
            self.dialect.validate_identifier(effective_schema)
        return not self.checkfirst or (
            await self.dialect.has_table(
                self.connection, table.name, schema=effective_schema
            )
        )

    async def _can_drop_sequence(self, sequence):
        """Return whether ``DROP SEQUENCE`` should be issued for *sequence*.

        A falsy result is returned when the dialect does not support
        sequences, when the sequence is optional on a dialect where sequences
        are optional, or when ``checkfirst`` is set and
        ``dialect.has_sequence`` reports that the sequence does not exist.
        """
        effective_schema = self.connection.schema_for_object(sequence)
        return self.dialect.supports_sequences and (
            (not self.dialect.sequences_optional or not sequence.optional)
            and (
                not self.checkfirst
                or await self.dialect.has_sequence(
                    self.connection, sequence.name, schema=effective_schema
                )
            )
        )

    async def visit_index(self, index):
        """Issue ``DropIndex`` for *index*."""
        await self.connection.status(DropIndex(index))

    async def visit_table(
        self, table, drop_ok=False, _is_metadata_operation=False
    ):
        """Drop *table*, then the objects behind its column defaults.

        Nothing happens when ``drop_ok`` is false and ``_can_drop_table``
        rejects the table.  Otherwise ``before_drop`` is dispatched,
        ``DropTable`` is issued, every non-``None`` column default is
        traversed and ``after_drop`` is dispatched.
        """
        if not drop_ok and not await self._can_drop_table(table):
            return

        await _Async(table.dispatch.before_drop)(
            table,
            self.connection,
            checkfirst=self.checkfirst,
            _ddl_runner=self,
            _is_metadata_operation=_is_metadata_operation,
        )

        await self.connection.status(DropTable(table))

        # Column defaults may refer to sequences that are also installed as
        # the column's server-side default.  The table depends on such a
        # sequence, so it has to be dropped only after the table is gone.
        for column in table.columns:
            if column.default is not None:
                await self.traverse_single(column.default)

        await _Async(table.dispatch.after_drop)(
            table,
            self.connection,
            checkfirst=self.checkfirst,
            _ddl_runner=self,
            _is_metadata_operation=_is_metadata_operation,
        )

    async def visit_foreign_key_constraint(self, constraint):
        """Issue ``DropConstraint`` for *constraint* if ALTER is supported."""
        if not self.dialect.supports_alter:
            return
        await self.connection.status(DropConstraint(constraint))

    async def visit_sequence(self, sequence, drop_ok=False):
        """Issue ``DropSequence`` unless the sequence is rejected.

        The ``_can_drop_sequence`` check is skipped when ``drop_ok`` is true.
        """
        if not drop_ok and not await self._can_drop_sequence(sequence):
            return
        await self.connection.status(DropSequence(sequence))
class GinoSchemaVisitor:
    """Asynchronous create/drop front end for a single schema item.

    The wrapped item is handed to the bind's ``_run_visitor`` together with
    ``AsyncSchemaGenerator`` (create) or ``AsyncSchemaDropper`` (drop).
    """

    __slots__ = ("_item",)

    def __init__(self, item):
        self._item = item

    async def create(self, bind=None, *args, **kwargs):
        """Run ``AsyncSchemaGenerator`` over the item and return the item.

        When *bind* is ``None`` it is resolved with ``_bind_or_error``.
        Extra positional and keyword arguments are passed to
        ``bind._run_visitor`` unchanged.
        """
        target_bind = _bind_or_error(self._item) if bind is None else bind
        run_visitor = getattr(target_bind, "_run_visitor")
        await run_visitor(AsyncSchemaGenerator, self._item, *args, **kwargs)
        return self._item

    async def drop(self, bind=None, *args, **kwargs):
        """Run ``AsyncSchemaDropper`` over the item.

        When *bind* is ``None`` it is resolved with ``_bind_or_error``.
        Extra positional and keyword arguments are passed to
        ``bind._run_visitor`` unchanged.
        """
        target_bind = _bind_or_error(self._item) if bind is None else bind
        run_visitor = getattr(target_bind, "_run_visitor")
        await run_visitor(AsyncSchemaDropper, self._item, *args, **kwargs)

    async def create_all(self, bind=None, tables=None, checkfirst=True):
        """Call :meth:`create` with *tables* and *checkfirst* as keywords."""
        await self.create(bind=bind, tables=tables, checkfirst=checkfirst)

    async def drop_all(self, bind=None, tables=None, checkfirst=True):
        """Call :meth:`drop` with *tables* and *checkfirst* as keywords."""
        await self.drop(bind=bind, tables=tables, checkfirst=checkfirst)
class AsyncSchemaTypeMixin:
    """Awaitable create/drop hooks for schema types.

    Every method resolves the dialect-specific implementation with
    ``self.dialect_impl(bind.dialect)`` and forwards the call to the method of
    the same name on that implementation.  Forwarding only happens when the
    implementation is a ``SchemaType`` of a different class than ``self``;
    otherwise the method does nothing.
    """

    async def create_async(self, bind=None, checkfirst=False):
        """Forward ``create_async`` to the dialect implementation.

        When *bind* is ``None`` it is resolved with ``_bind_or_error``.
        """
        if bind is None:
            bind = _bind_or_error(self)
        impl = self.dialect_impl(bind.dialect)
        if impl.__class__ is self.__class__ or not isinstance(impl, SchemaType):
            return
        await impl.create_async(bind=bind, checkfirst=checkfirst)

    async def drop_async(self, bind=None, checkfirst=False):
        """Forward ``drop_async`` to the dialect implementation.

        When *bind* is ``None`` it is resolved with ``_bind_or_error``.
        """
        if bind is None:
            bind = _bind_or_error(self)
        impl = self.dialect_impl(bind.dialect)
        if impl.__class__ is self.__class__ or not isinstance(impl, SchemaType):
            return
        await impl.drop_async(bind=bind, checkfirst=checkfirst)

    async def _on_table_create_async(self, target, bind, **kw):
        """Forward the table-create hook to the dialect implementation.

        Skipped entirely when ``_is_impl_for_variant`` rejects the dialect.
        """
        if not self._is_impl_for_variant(bind.dialect, kw):
            return
        impl = self.dialect_impl(bind.dialect)
        if impl.__class__ is self.__class__ or not isinstance(impl, SchemaType):
            return
        hook = getattr(impl, "_on_table_create_async")
        await hook(target, bind, **kw)

    async def _on_table_drop_async(self, target, bind, **kw):
        """Forward the table-drop hook to the dialect implementation.

        Skipped entirely when ``_is_impl_for_variant`` rejects the dialect.
        """
        if not self._is_impl_for_variant(bind.dialect, kw):
            return
        impl = self.dialect_impl(bind.dialect)
        if impl.__class__ is self.__class__ or not isinstance(impl, SchemaType):
            return
        hook = getattr(impl, "_on_table_drop_async")
        await hook(target, bind, **kw)

    async def _on_metadata_create_async(self, target, bind, **kw):
        """Forward the metadata-create hook to the dialect implementation.

        Skipped entirely when ``_is_impl_for_variant`` rejects the dialect.
        """
        if not self._is_impl_for_variant(bind.dialect, kw):
            return
        impl = self.dialect_impl(bind.dialect)
        if impl.__class__ is self.__class__ or not isinstance(impl, SchemaType):
            return
        hook = getattr(impl, "_on_metadata_create_async")
        await hook(target, bind, **kw)

    async def _on_metadata_drop_async(self, target, bind, **kw):
        """Forward the metadata-drop hook to the dialect implementation.

        Skipped entirely when ``_is_impl_for_variant`` rejects the dialect.
        """
        if not self._is_impl_for_variant(bind.dialect, kw):
            return
        impl = self.dialect_impl(bind.dialect)
        if impl.__class__ is self.__class__ or not isinstance(impl, SchemaType):
            return
        hook = getattr(impl, "_on_metadata_drop_async")
        await hook(target, bind, **kw)
async def _call_portable_instancemethod(fn, args, kw): m = getattr(fn.target, fn.name + '_async', None) if m is None: return fn(*args, **kw) else: kw.update(fn.kwargs) return await m(*args, **kw) class _Async: def __init__(self, listener): self._listener = listener async def call(self, *args, **kw): for fn in self._listener.parent_listeners: await _call_portable_instancemethod(fn, args, kw) for fn in self._listener.listeners: await _call_portable_instancemethod(fn, args, kw) def __call__(self, *args, **kwargs): return self.call(*args, **kwargs)
def patch_schema(db):
    """Replace schema types on *db* with ``AsyncSchemaTypeMixin`` variants.

    For every listed attribute name (currently only ``"Enum"``) a new class
    with the same name is created that derives from the existing ``db``
    attribute and from ``AsyncSchemaTypeMixin``, and is stored back on *db*
    under that name.
    """
    for type_name in {"Enum"}:
        original_type = getattr(db, type_name)
        patched_type = type(
            type_name, (original_type, AsyncSchemaTypeMixin), {}
        )
        setattr(db, type_name, patched_type)