# noinspection PyProtectedMember
from sqlalchemy import exc, util
from sqlalchemy.sql.base import _bind_or_error
from sqlalchemy.sql.ddl import (
AddConstraint,
CreateIndex,
CreateSequence,
CreateTable,
DropConstraint,
DropIndex,
DropSequence,
DropTable,
SchemaDropper,
SchemaGenerator,
SetColumnComment,
SetTableComment,
sort_tables_and_constraints,
)
from sqlalchemy.types import SchemaType
class AsyncVisitor:
    """Mixin that provides an awaitable ``traverse_single``.

    The host class is expected to supply ``_visitor_iterator``, an iterable
    of visitor objects whose ``visit_<name>`` attributes return awaitables.
    """

    async def traverse_single(self, obj, **kw):
        """Dispatch *obj* to the first visitor able to handle it.

        The handler name is ``visit_`` followed by ``obj.__visit_name__``.
        Visitors are tried in the order produced by
        ``self._visitor_iterator``; the first one exposing a truthy attribute
        of that name is called with ``obj`` and ``**kw`` and awaited.

        :return: the awaited handler result, or ``None`` when no visitor
            offers a handler.
        """
        # noinspection PyUnresolvedReferences
        for visitor in self._visitor_iterator:
            handler = getattr(
                visitor, "visit_%s" % obj.__visit_name__, None)
            if not handler:
                continue
            # Only the first matching visitor is consulted.
            return await handler(obj, **kw)
class AsyncSchemaGenerator(AsyncVisitor, SchemaGenerator):
    """Coroutine-based variant of ``SchemaGenerator``.

    Existence checks are awaited on the dialect and every DDL element is
    issued through ``await self.connection.status(...)``.
    """

    async def _can_create_table(self, table):
        """Tell whether a CREATE TABLE should be issued for *table*.

        The table name and, when present, its effective schema are validated
        by the dialect first.  Without ``checkfirst`` the answer is always
        ``True``; otherwise it is ``True`` only if the dialect reports that
        the table does not exist.
        """
        dialect = self.dialect
        dialect.validate_identifier(table.name)
        schema = self.connection.schema_for_object(table)
        if schema:
            dialect.validate_identifier(schema)
        if not self.checkfirst:
            return True
        exists = await dialect.has_table(
            self.connection, table.name, schema=schema)
        return not exists

    async def _can_create_sequence(self, sequence):
        """Tell whether a CREATE SEQUENCE should be issued for *sequence*.

        A falsy ``dialect.supports_sequences`` is returned as-is.  ``False``
        is returned when the dialect treats sequences as optional and the
        sequence is marked optional.  Otherwise the result is ``True`` unless
        ``checkfirst`` is set and the dialect reports the sequence exists.
        """
        dialect = self.dialect
        schema = self.connection.schema_for_object(sequence)
        supported = dialect.supports_sequences
        if not supported:
            return supported
        if dialect.sequences_optional and sequence.optional:
            return False
        if not self.checkfirst:
            return True
        return not await dialect.has_sequence(
            self.connection, sequence.name, schema=schema)

    async def visit_table(
            self, table, create_ok=False,
            include_foreign_key_constraints=None,
            _is_metadata_operation=False):
        """Create *table* together with its defaults, indexes and comments.

        :param table: the table to create.
        :param create_ok: when true, skip the ``_can_create_table`` check.
        :param include_foreign_key_constraints: forwarded to ``CreateTable``;
            replaced by ``None`` when the dialect does not support ALTER.
        :param _is_metadata_operation: passed through to the
            ``before_create`` / ``after_create`` listeners.
        """
        if not (create_ok or await self._can_create_table(table)):
            return

        async def emit(event):
            # Arguments are re-read on every call, exactly once per event.
            await _Async(event)(
                table, self.connection,
                checkfirst=self.checkfirst,
                _ddl_runner=self,
                _is_metadata_operation=_is_metadata_operation)

        await emit(table.dispatch.before_create)

        for column in table.columns:
            default = column.default
            if default is not None:
                await self.traverse_single(default)

        # e.g., don't omit any foreign key constraints
        fk_constraints = (
            include_foreign_key_constraints
            if self.dialect.supports_alter else None)
        await self.connection.status(
            CreateTable(
                table, include_foreign_key_constraints=fk_constraints))

        # Objects without an ``indexes`` attribute simply have none to visit.
        for index in getattr(table, 'indexes', ()):
            await self.traverse_single(index)

        # Comments need separate statements only when the dialect supports
        # them but does not render them inline.
        if self.dialect.supports_comments and not self.dialect.inline_comments:
            if table.comment is not None:
                await self.connection.status(SetTableComment(table))
            for column in table.columns:
                if column.comment is not None:
                    await self.connection.status(SetColumnComment(column))

        await emit(table.dispatch.after_create)

    async def visit_foreign_key_constraint(self, constraint):
        """Issue ``AddConstraint`` for *constraint* if ALTER is supported."""
        if self.dialect.supports_alter:
            await self.connection.status(AddConstraint(constraint))

    async def visit_sequence(self, sequence, create_ok=False):
        """Issue ``CreateSequence`` for *sequence*.

        Unless *create_ok* is true, nothing is issued when
        ``_can_create_sequence`` answers falsy.
        """
        if create_ok or await self._can_create_sequence(sequence):
            await self.connection.status(CreateSequence(sequence))

    async def visit_index(self, index):
        """Issue ``CreateIndex`` for *index*."""
        await self.connection.status(CreateIndex(index))
class AsyncSchemaDropper(AsyncVisitor, SchemaDropper):
    """Coroutine-based variant of ``SchemaDropper``.

    Existence checks are awaited on the dialect and every DDL element is
    issued through ``await self.connection.status(...)``.
    """

    async def _can_drop_table(self, table):
        """Tell whether a DROP TABLE should be issued for *table*.

        The table name and, when present, its effective schema are validated
        by the dialect first.  Without ``checkfirst`` the answer is always
        ``True``; otherwise it is whatever the dialect's ``has_table``
        reports.
        """
        self.dialect.validate_identifier(table.name)
        effective_schema = self.connection.schema_for_object(table)
        if effective_schema:
            self.dialect.validate_identifier(effective_schema)
        return not self.checkfirst or (await self.dialect.has_table(
            self.connection, table.name, schema=effective_schema))

    async def _can_drop_sequence(self, sequence):
        """Tell whether a DROP SEQUENCE should be issued for *sequence*.

        Requires dialect sequence support and a sequence that is not
        skipped as optional; with ``checkfirst`` set, the dialect must also
        report that the sequence exists.
        """
        effective_schema = self.connection.schema_for_object(sequence)
        return self.dialect.supports_sequences and (
            (not self.dialect.sequences_optional or
             not sequence.optional) and
            (not self.checkfirst or
             await self.dialect.has_sequence(
                 self.connection,
                 sequence.name,
                 schema=effective_schema))
        )

    async def visit_index(self, index):
        """Issue ``DropIndex`` for *index*."""
        await self.connection.status(DropIndex(index))

    async def visit_table(self, table, drop_ok=False,
                          _is_metadata_operation=False):
        """Drop *table*, then visit the defaults attached to its columns.

        :param table: the table to drop.
        :param drop_ok: when true, skip the ``_can_drop_table`` check.
        :param _is_metadata_operation: passed through to the
            ``before_drop`` / ``after_drop`` listeners.
        """
        if not drop_ok and not await self._can_drop_table(table):
            return

        await _Async(table.dispatch.before_drop)(
            table, self.connection,
            checkfirst=self.checkfirst,
            _ddl_runner=self,
            _is_metadata_operation=_is_metadata_operation)

        await self.connection.status(DropTable(table))

        # Column defaults are visited only AFTER the table is gone.  A
        # client-side default such as a sequence may also back a server-side
        # default of this very table; dropping it while the table still
        # exists can be rejected by the database.  This mirrors the ordering
        # used by SQLAlchemy's synchronous SchemaDropper.
        for column in table.columns:
            if column.default is not None:
                await self.traverse_single(column.default)

        await _Async(table.dispatch.after_drop)(
            table, self.connection,
            checkfirst=self.checkfirst,
            _ddl_runner=self,
            _is_metadata_operation=_is_metadata_operation)

    async def visit_foreign_key_constraint(self, constraint):
        """Issue ``DropConstraint`` for *constraint* if ALTER is supported."""
        if not self.dialect.supports_alter:
            return
        await self.connection.status(DropConstraint(constraint))

    async def visit_sequence(self, sequence, drop_ok=False):
        """Issue ``DropSequence`` for *sequence*.

        Unless *drop_ok* is true, nothing is issued when
        ``_can_drop_sequence`` answers falsy.
        """
        if not drop_ok and not await self._can_drop_sequence(sequence):
            return
        await self.connection.status(DropSequence(sequence))
class GinoSchemaVisitor:
    """Async create/drop facade around a single schema item.

    The wrapped item is stored in ``_item`` and handed to the bind's
    ``_run_visitor`` together with the matching async DDL visitor class.
    """

    __slots__ = ('_item',)

    def __init__(self, item):
        self._item = item

    async def create(self, bind=None, *args, **kwargs):
        """Run ``AsyncSchemaGenerator`` over the wrapped item.

        :param bind: object providing ``_run_visitor``; when ``None`` it is
            resolved from the item with ``_bind_or_error``.
        :param args: extra positional arguments for ``_run_visitor``.
        :param kwargs: extra keyword arguments for ``_run_visitor``.
        :return: the wrapped item.
        """
        target = self._item
        engine = _bind_or_error(target) if bind is None else bind
        # Looked up by name to avoid a protected-member access warning.
        run_visitor = getattr(engine, '_run_visitor')
        await run_visitor(AsyncSchemaGenerator, target, *args, **kwargs)
        return target

    async def drop(self, bind=None, *args, **kwargs):
        """Run ``AsyncSchemaDropper`` over the wrapped item.

        Arguments are the same as for :meth:`create`; nothing is returned.
        """
        target = self._item
        engine = _bind_or_error(target) if bind is None else bind
        run_visitor = getattr(engine, '_run_visitor')
        await run_visitor(AsyncSchemaDropper, target, *args, **kwargs)

    async def create_all(self, bind=None, tables=None, checkfirst=True):
        """Call :meth:`create` with *tables* and *checkfirst* as keywords.

        Unlike :meth:`create`, this returns ``None``.
        """
        options = dict(tables=tables, checkfirst=checkfirst)
        await self.create(bind=bind, **options)

    async def drop_all(self, bind=None, tables=None, checkfirst=True):
        """Call :meth:`drop` with *tables* and *checkfirst* as keywords."""
        options = dict(tables=tables, checkfirst=checkfirst)
        await self.drop(bind=bind, **options)
def _dialect_specific_impl(schema_type, dialect):
    """Return the dialect implementation to delegate to, or ``None``.

    Delegation only happens when the implementation is of a different class
    than *schema_type* and is itself a ``SchemaType``.
    """
    impl = schema_type.dialect_impl(dialect)
    if impl.__class__ is schema_type.__class__:
        return None
    return impl if isinstance(impl, SchemaType) else None


async def _forward_schema_event(schema_type, hook_name, target, bind, kw):
    """Await *hook_name* on the dialect implementation of *schema_type*.

    Nothing happens when ``_is_impl_for_variant`` rejects the bind's dialect
    or when there is no distinct ``SchemaType`` implementation.
    """
    if not schema_type._is_impl_for_variant(bind.dialect, kw):
        return
    impl = _dialect_specific_impl(schema_type, bind.dialect)
    if impl is not None:
        await getattr(impl, hook_name)(target, bind, **kw)


class AsyncSchemaTypeMixin:
    """Adds ``*_async`` counterparts of the schema-type create/drop hooks.

    Every method delegates to the same-named method of the type's
    dialect-specific implementation, when that implementation is a distinct
    ``SchemaType``.
    """

    async def create_async(self, bind=None, checkfirst=False):
        """Delegate creation to the dialect-specific implementation.

        :param bind: bind whose ``dialect`` selects the implementation; when
            ``None`` it is resolved with ``_bind_or_error``.
        :param checkfirst: forwarded to the implementation.
        """
        if bind is None:
            bind = _bind_or_error(self)
        impl = _dialect_specific_impl(self, bind.dialect)
        if impl is not None:
            await impl.create_async(bind=bind, checkfirst=checkfirst)

    async def drop_async(self, bind=None, checkfirst=False):
        """Delegate dropping to the dialect-specific implementation.

        Arguments are the same as for :meth:`create_async`.
        """
        if bind is None:
            bind = _bind_or_error(self)
        impl = _dialect_specific_impl(self, bind.dialect)
        if impl is not None:
            await impl.drop_async(bind=bind, checkfirst=checkfirst)

    async def _on_table_create_async(self, target, bind, **kw):
        """Forward the table-create hook to the dialect implementation."""
        await _forward_schema_event(
            self, '_on_table_create_async', target, bind, kw)

    async def _on_table_drop_async(self, target, bind, **kw):
        """Forward the table-drop hook to the dialect implementation."""
        await _forward_schema_event(
            self, '_on_table_drop_async', target, bind, kw)

    async def _on_metadata_create_async(self, target, bind, **kw):
        """Forward the metadata-create hook to the dialect implementation."""
        await _forward_schema_event(
            self, '_on_metadata_create_async', target, bind, kw)

    async def _on_metadata_drop_async(self, target, bind, **kw):
        """Forward the metadata-drop hook to the dialect implementation."""
        await _forward_schema_event(
            self, '_on_metadata_drop_async', target, bind, kw)
async def _call_portable_instancemethod(fn, args, kw):
    """Invoke one event listener, preferring its ``*_async`` variant.

    When *fn* carries ``target`` and ``name`` attributes and the target has
    a ``<name>_async`` attribute, that attribute is called with *args* and
    with *kw* extended by ``fn.kwargs``.  In every other case, including
    plain callables that have no ``target`` / ``name``, *fn* itself is
    called with *args* and *kw*.

    :param fn: the listener to invoke.
    :param args: positional arguments for the listener.
    :param kw: keyword arguments for the listener; never modified.
    :return: the listener's result, awaited first if it is awaitable.
    """
    import inspect

    # Listeners may be ordinary callables, so probe instead of assuming the
    # ``target`` / ``name`` attributes exist.
    target = getattr(fn, 'target', None)
    name = getattr(fn, 'name', None)
    async_method = None
    if target is not None and isinstance(name, str):
        async_method = getattr(target, name + '_async', None)

    if async_method is None:
        result = fn(*args, **kw)
    else:
        # Merge into a private copy: the caller reuses ``kw`` for every
        # listener, so updating it in place would leak this listener's bound
        # keyword arguments into all the following ones.
        merged = dict(kw)
        merged.update(getattr(fn, 'kwargs', ()))
        result = async_method(*args, **merged)

    if inspect.isawaitable(result):
        result = await result
    return result
class _Async:
    """Awaitable stand-in for calling an event listener collection.

    Calling an instance returns a coroutine that runs every listener of the
    wrapped collection through ``_call_portable_instancemethod``.
    """

    def __init__(self, listener):
        self._listener = listener

    def __call__(self, *args, **kwargs):
        return self.call(*args, **kwargs)

    async def call(self, *args, **kw):
        """Await all listeners: ``parent_listeners`` first, then ``listeners``.

        Listeners run one after another, each receiving the same *args* and
        *kw*.
        """
        for group in ('parent_listeners', 'listeners'):
            # Each group is looked up only when its turn comes.
            for fn in getattr(self._listener, group):
                await _call_portable_instancemethod(fn, args, kw)
def patch_schema(db):
    """Replace schema types on *db* with async-capable subclasses.

    For each patched name (currently only ``Enum``) the attribute of *db* is
    replaced by a new class of the same name that inherits from the original
    attribute first and from ``AsyncSchemaTypeMixin`` second.
    """
    for type_name in ('Enum',):
        original = getattr(db, type_name)
        patched = type(type_name, (original, AsyncSchemaTypeMixin), {})
        setattr(db, type_name, patched)