import inspect
import itertools
import time
import asyncpg
from sqlalchemy import util, exc, sql
from sqlalchemy.dialects.postgresql import ( # noqa: F401
ARRAY,
CreateEnumType,
DropEnumType,
JSON,
JSONB
)
from sqlalchemy.dialects.postgresql.base import (
ENUM,
PGCompiler,
PGDialect,
PGExecutionContext,
)
from sqlalchemy.sql import sqltypes
from . import base
JSON_COLTYPE = 114
JSONB_COLTYPE = 3802
[文档]class AsyncpgDBAPI(base.BaseDBAPI):
Error = asyncpg.PostgresError, asyncpg.InterfaceError
[文档]class AsyncpgCompiler(PGCompiler):
@property
def bindtemplate(self):
return self._bindtemplate
@bindtemplate.setter
def bindtemplate(self, val):
# noinspection PyAttributeOutsideInit
self._bindtemplate = val.replace(':', '$')
def _apply_numbered_params(self):
if hasattr(self, 'string'):
return super()._apply_numbered_params()
# noinspection PyAbstractClass
[文档]class AsyncpgExecutionContext(base.ExecutionContextOverride,
PGExecutionContext):
async def _execute_scalar(self, stmt, type_):
conn = self.root_connection
if isinstance(stmt, util.text_type) and \
not self.dialect.supports_unicode_statements:
stmt = self.dialect._encoder(stmt)[0]
if self.dialect.positional:
default_params = self.dialect.execute_sequence_format()
else:
default_params = {}
conn._cursor_execute(self.cursor, stmt, default_params, context=self)
r = await self.cursor.async_execute(stmt, None, default_params, 1)
r = r[0][0]
if type_ is not None:
# apply type post processors to the result
proc = type_._cached_result_processor(
self.dialect,
self.cursor.description[0][1]
)
if proc:
return proc(r)
return r
[文档]class AsyncpgIterator:
def __init__(self, context, iterator):
self._context = context
self._iterator = iterator
async def __anext__(self):
row = await self._iterator.__anext__()
return self._context.process_rows([row])[0]
[文档]class AsyncpgCursor(base.Cursor):
def __init__(self, context, cursor):
self._context = context
self._cursor = cursor
async def many(self, n, *, timeout=base.DEFAULT):
if timeout is base.DEFAULT:
timeout = self._context.timeout
rows = await self._cursor.fetch(n, timeout=timeout)
return self._context.process_rows(rows)
async def next(self, *, timeout=base.DEFAULT):
if timeout is base.DEFAULT:
timeout = self._context.timeout
row = await self._cursor.fetchrow(timeout=timeout)
if not row:
return None
return self._context.process_rows([row])[0]
async def forward(self, n, *, timeout=base.DEFAULT):
if timeout is base.DEFAULT:
timeout = self._context.timeout
await self._cursor.forward(n, timeout=timeout)
[文档]class PreparedStatement(base.PreparedStatement):
def __init__(self, prepared, clause=None):
super().__init__(clause)
self._prepared = prepared
def _get_iterator(self, *params, **kwargs):
return AsyncpgIterator(
self.context, self._prepared.cursor(*params, **kwargs).__aiter__())
async def _get_cursor(self, *params, **kwargs):
iterator = await self._prepared.cursor(*params, **kwargs)
return AsyncpgCursor(self.context, iterator)
async def _execute(self, params, one):
if one:
rv = await self._prepared.fetchrow(*params)
if rv is None:
rv = []
else:
rv = [rv]
else:
rv = await self._prepared.fetch(*params)
return self._prepared.get_statusmsg(), rv
[文档]class DBAPICursor(base.DBAPICursor):
def __init__(self, dbapi_conn):
self._conn = dbapi_conn
self._attributes = None
self._status = None
async def prepare(self, context, clause=None):
timeout = context.timeout
if timeout is None:
conn = await self._conn.acquire(timeout=timeout)
else:
before = time.monotonic()
conn = await self._conn.acquire(timeout=timeout)
after = time.monotonic()
timeout -= after - before
prepared = await conn.prepare(context.statement, timeout=timeout)
try:
self._attributes = prepared.get_attributes()
except TypeError: # asyncpg <= 0.12.0
self._attributes = []
rv = PreparedStatement(prepared, clause)
rv.context = context
return rv
async def async_execute(self, query, timeout, args, limit=0, many=False):
if timeout is None:
conn = await self._conn.acquire(timeout=timeout)
else:
before = time.monotonic()
conn = await self._conn.acquire(timeout=timeout)
after = time.monotonic()
timeout -= after - before
_protocol = getattr(conn, '_protocol')
timeout = getattr(_protocol, '_get_timeout')(timeout)
def executor(state, timeout_):
if many:
return _protocol.bind_execute_many(state, args, '', timeout_)
else:
return _protocol.bind_execute(state, args, '', limit, True,
timeout_)
with getattr(conn, '_stmt_exclusive_section'):
result, stmt = await getattr(conn, '_do_execute')(
query, executor, timeout)
try:
self._attributes = getattr(stmt, '_get_attributes')()
except TypeError: # asyncpg <= 0.12.0
self._attributes = []
if not many:
result, self._status = result[:2]
return result
@property
def description(self):
return [((a[0], a[1][0]) + (None,) * 5) for a in self._attributes]
[文档] def get_statusmsg(self):
return self._status.decode()
[文档]class Pool(base.Pool):
def __init__(self, url, loop, **kwargs):
self._url = url
self._loop = loop
self._kwargs = kwargs
self._pool = None
async def _init(self):
args = self._kwargs.copy()
args.update(
loop=self._loop,
host=self._url.host,
port=self._url.port,
user=self._url.username,
database=self._url.database,
password=self._url.password,
)
self._pool = await asyncpg.create_pool(**args)
return self
def __await__(self):
return self._init().__await__()
@property
def raw_pool(self):
return self._pool
async def acquire(self, *, timeout=None):
return await self._pool.acquire(timeout=timeout)
async def release(self, conn):
await self._pool.release(conn)
async def close(self):
await self._pool.close()
[文档]class NullPool(base.Pool):
# TODO: generic NullPool, abstracting connection part
def __init__(self, url, loop, **kwargs):
self._loop = loop
self._kwargs = dict()
for k in inspect.getfullargspec(asyncpg.connect).kwonlyargs:
if k in kwargs:
self._kwargs[k] = kwargs[k]
self._kwargs.update(dict(
host=url.host,
port=url.port,
user=url.username,
database=url.database,
password=url.password,
))
def __await__(self):
async def return_self():
return self
return return_self().__await__()
@property
def raw_pool(self):
return self
async def acquire(self, *, timeout=None):
args = self._kwargs.copy()
if timeout is not None:
args.update(timeout=timeout)
return await asyncpg.connect(loop=self._loop, **args)
async def release(self, conn):
await conn.close()
async def close(self):
pass
[文档]class Transaction(base.Transaction):
def __init__(self, tx):
self._tx = tx
@property
def raw_transaction(self):
return self._tx
async def begin(self):
await self._tx.start()
async def commit(self):
await self._tx.commit()
async def rollback(self):
await self._tx.rollback()
[文档]class AsyncEnum(ENUM):
async def create_async(self, bind=None, checkfirst=True):
if not checkfirst or \
not await bind.dialect.has_type(
bind, self.name, schema=self.schema):
await bind.status(CreateEnumType(self))
async def drop_async(self, bind=None, checkfirst=True):
if not checkfirst or \
await bind.dialect.has_type(bind, self.name,
schema=self.schema):
await bind.status(DropEnumType(self))
async def _on_table_create_async(self, target, bind, checkfirst=False,
**kw):
if checkfirst or (
not self.metadata and
not kw.get('_is_metadata_operation', False)) and \
not self._check_for_name_in_memos(checkfirst, kw):
await self.create_async(bind=bind, checkfirst=checkfirst)
async def _on_table_drop_async(self, target, bind, checkfirst=False, **kw):
if not self.metadata and \
not kw.get('_is_metadata_operation', False) and \
not self._check_for_name_in_memos(checkfirst, kw):
await self.drop_async(bind=bind, checkfirst=checkfirst)
async def _on_metadata_create_async(self, target, bind, checkfirst=False,
**kw):
if not self._check_for_name_in_memos(checkfirst, kw):
await self.create_async(bind=bind, checkfirst=checkfirst)
async def _on_metadata_drop_async(self, target, bind, checkfirst=False,
**kw):
if not self._check_for_name_in_memos(checkfirst, kw):
await self.drop_async(bind=bind, checkfirst=checkfirst)
[文档]class GinoNullType(sqltypes.NullType):
[文档] def result_processor(self, dialect, coltype):
if coltype == JSON_COLTYPE:
return JSON().result_processor(dialect, coltype)
if coltype == JSONB_COLTYPE:
return JSONB().result_processor(dialect, coltype)
return super().result_processor(dialect, coltype)
# noinspection PyAbstractClass
[文档]class AsyncpgDialect(PGDialect, base.AsyncDialectMixin):
driver = 'asyncpg'
supports_native_decimal = True
dbapi_class = AsyncpgDBAPI
statement_compiler = AsyncpgCompiler
execution_ctx_cls = AsyncpgExecutionContext
cursor_cls = DBAPICursor
init_kwargs = set(itertools.chain(
*[inspect.getfullargspec(f).kwonlydefaults.keys() for f in
[asyncpg.create_pool, asyncpg.connect]]))
colspecs = util.update_copy(
PGDialect.colspecs,
{
ENUM: AsyncEnum,
sqltypes.Enum: AsyncEnum,
sqltypes.NullType: GinoNullType,
}
)
def __init__(self, *args, **kwargs):
self._pool_kwargs = {}
for k in self.init_kwargs:
if k in kwargs:
self._pool_kwargs[k] = kwargs.pop(k)
super().__init__(*args, **kwargs)
self._init_mixin()
async def init_pool(self, url, loop, pool_class=None):
if pool_class is None:
pool_class = Pool
return await pool_class(url, loop, init=self.on_connect(),
**self._pool_kwargs)
# noinspection PyMethodMayBeStatic
[文档] def transaction(self, raw_conn, args, kwargs):
return Transaction(raw_conn.transaction(*args, **kwargs))
[文档] def on_connect(self):
if self.isolation_level is not None:
async def connect(conn):
await self.set_isolation_level(conn, self.isolation_level)
return connect
else:
return None
async def set_isolation_level(self, connection, level):
"""
Given an asyncpg connection, set its isolation level.
"""
level = level.replace('_', ' ')
if level not in self._isolation_lookup:
raise exc.ArgumentError(
"Invalid value '%s' for isolation_level. "
"Valid isolation levels for %s are %s" %
(level, self.name, ", ".join(self._isolation_lookup))
)
await connection.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION "
"ISOLATION LEVEL %s" % level)
await connection.execute("COMMIT")
async def get_isolation_level(self, connection):
"""
Given an asyncpg connection, return its isolation level.
"""
val = await connection.fetchval('show transaction isolation level')
return val.upper()
async def has_schema(self, connection, schema):
row = await connection.first(
sql.text(
"select nspname from pg_namespace "
"where lower(nspname)=:schema"
).bindparams(
sql.bindparam(
'schema',
util.text_type(schema.lower()),
type_=sqltypes.Unicode,
)
)
)
return bool(row)
async def has_table(self, connection, table_name, schema=None):
# seems like case gets folded in pg_class...
if schema is None:
row = await connection.first(
sql.text(
"select relname from pg_class c join pg_namespace n on "
"n.oid=c.relnamespace where "
"pg_catalog.pg_table_is_visible(c.oid) "
"and relname=:name"
).bindparams(
sql.bindparam(
'name',
util.text_type(table_name),
type_=sqltypes.Unicode
),
)
)
else:
row = await connection.first(
sql.text(
"select relname from pg_class c join pg_namespace n on "
"n.oid=c.relnamespace where n.nspname=:schema and "
"relname=:name"
).bindparams(
sql.bindparam(
'name',
util.text_type(table_name),
type_=sqltypes.Unicode,
),
sql.bindparam(
'schema',
util.text_type(schema),
type_=sqltypes.Unicode,
)
)
)
return bool(row)
async def has_sequence(self, connection, sequence_name, schema=None):
if schema is None:
row = await connection.first(
sql.text(
"SELECT relname FROM pg_class c join pg_namespace n on "
"n.oid=c.relnamespace where relkind='S' and "
"n.nspname=current_schema() "
"and relname=:name"
).bindparams(
sql.bindparam(
'name',
util.text_type(sequence_name),
type_=sqltypes.Unicode,
)
)
)
else:
row = await connection.first(
sql.text(
"SELECT relname FROM pg_class c join pg_namespace n on "
"n.oid=c.relnamespace where relkind='S' and "
"n.nspname=:schema and relname=:name"
).bindparams(
sql.bindparam(
'name',
util.text_type(sequence_name),
type_=sqltypes.Unicode,
),
sql.bindparam(
'schema',
util.text_type(schema),
type_=sqltypes.Unicode,
)
)
)
return bool(row)
async def has_type(self, connection, type_name, schema=None):
if schema is not None:
query = """
SELECT EXISTS (
SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n
WHERE t.typnamespace = n.oid
AND t.typname = :typname
AND n.nspname = :nspname
)
"""
query = sql.text(query)
else:
query = """
SELECT EXISTS (
SELECT * FROM pg_catalog.pg_type t
WHERE t.typname = :typname
AND pg_type_is_visible(t.oid)
)
"""
query = sql.text(query)
query = query.bindparams(
sql.bindparam(
'typname',
util.text_type(type_name),
type_=sqltypes.Unicode,
),
)
if schema is not None:
query = query.bindparams(
sql.bindparam(
'nspname',
util.text_type(schema),
type_=sqltypes.Unicode,
),
)
return bool(await connection.scalar(query))