1. Create tables in all binds
Observation: db.create_all() calls self.get_tables_for_bind().
Solution: Override SQLAlchemy get_tables_for_bind() so that tables whose bind key is '__all__' are returned for every named bind.
class MySQLAlchemy(SQLAlchemy):
    """SQLAlchemy subclass whose table lookup understands the '__all__' bind key."""

    def get_tables_for_bind(self, bind=None):
        """Return the tables of ``self.Model.metadata`` selected for ``bind``.

        A table is selected when its ``info['bind_key']`` equals ``bind``, or
        when ``bind`` is not None and the table's bind key is ``'__all__'``.
        """
        # The check being replaced was: table.info.get('bind_key') == bind
        def is_selected(table):
            key = table.info.get('bind_key')
            return key == bind or (bind is not None and key == '__all__')

        return [
            table
            for table in self.Model.metadata.tables.values()
            if is_selected(table)
        ]
Usage:
# Instantiate the subclass defined above instead of the stock extension class.
# db = SQLAlchemy(app) # Replace this
db = MySQLAlchemy(app) # with this
# create_all() goes through get_tables_for_bind(), so the override applies here.
db.create_all()
2. Choose a specific bind dynamically
Observation: SignallingSession get_bind() is responsible for determining the bind.
Solution:
- Override SignallingSession get_bind() to get the bind key from some context.
- Override SQLAlchemy create_session() to use our custom session class.
- Provide a context manager on db for choosing a specific bind, so that it is easy to access.
- Force a context to be specified for tables with '__all__' as bind key, by overriding SQLAlchemy get_binds() to restore the default engine.
class MySignallingSession(SignallingSession):
    """Session that resolves '__all__' tables via ``db.context_bind_key``."""

    def __init__(self, db, *args, **kwargs):
        """Initialise the parent session, then keep a reference to ``db``."""
        super().__init__(db, *args, **kwargs)
        self.db = db

    def get_bind(self, mapper=None, clause=None):
        """Delegate to the parent ``get_bind()``, substituting the bind key of '__all__' tables.

        When ``mapper`` is given and its ``persist_selectable.info`` carries
        ``bind_key == '__all__'``, that entry is replaced by
        ``self.db.context_bind_key`` for the duration of the parent call and
        then reset to ``'__all__'``. In every other case the parent
        implementation is called unchanged.
        """
        if mapper is None:
            return super().get_bind(mapper=mapper, clause=clause)

        table_info = getattr(mapper.persist_selectable, 'info', {})
        if table_info.get('bind_key') != '__all__':
            return super().get_bind(mapper=mapper, clause=clause)

        # Swap the placeholder for the bind key chosen through the context so
        # the parent lookup sees a concrete key; always put it back afterwards.
        # NOTE(review): this mutates the selectable's own info dict while the
        # parent call runs -- confirm that is acceptable under concurrency.
        table_info['bind_key'] = self.db.context_bind_key
        try:
            return super().get_bind(mapper=mapper, clause=clause)
        finally:
            table_info['bind_key'] = '__all__'
class MySQLAlchemy(SQLAlchemy):
    """SQLAlchemy subclass that lets '__all__' tables follow a bind chosen via ``context()``."""

    # Bind key selected by the innermost active context(); None by default.
    context_bind_key = None

    @contextmanager
    def context(self, bind=None):
        """Set ``context_bind_key`` to ``bind`` for the duration of a ``with`` block.

        The value that was in place on entry is restored on exit, including
        when the block raises, so contexts can be nested.
        """
        previous_key = self.context_bind_key
        try:
            self.context_bind_key = bind
            yield
        finally:
            self.context_bind_key = previous_key

    def create_session(self, options):
        """Return a sessionmaker that builds ``MySignallingSession`` objects for this ``db``."""
        return orm.sessionmaker(class_=MySignallingSession, db=self, **options)

    def get_binds(self, app=None):
        """Return the parent's table-to-engine mapping with '__all__' tables reset.

        Every table whose bind key is ``'__all__'`` is mapped to the engine
        returned by ``get_engine(app, None)`` (the default engine).
        """
        table_engines = super().get_binds(app=app)
        # Restore default engine for table.info.get('bind_key') == '__all__'
        default_engine = self.get_engine(self.get_app(app), None)
        for table in self.get_tables_for_bind('__all__'):
            table_engines[table] = default_engine
        return table_engines

    def get_tables_for_bind(self, bind=None):
        """Return the tables of ``self.Model.metadata`` selected for ``bind``.

        A table is selected when its ``info['bind_key']`` equals ``bind``, or
        when ``bind`` is not None and the table's bind key is ``'__all__'``.
        """
        def is_selected(table):
            key = table.info.get('bind_key')
            return key == bind or (bind is not None and key == '__all__')

        return [
            table
            for table in self.Model.metadata.tables.values()
            if is_selected(table)
        ]
Usage:
class Patient(db.Model):
    # '__all__' is the bind key that get_tables_for_bind() matches for every
    # named bind.
    __bind_key__ = "__all__"  # Add this
    __tablename__ = "patients"
Test case:
# Rows flushed under one bind context are not counted under another.
with db.context(bind='clinic1'):
    db.session.add(Patient())
    db.session.flush()  # Flush in 'clinic1'

    # A nested context switches the bind until its block ends.
    with db.context(bind='clinic2'):
        count_in_clinic2 = Patient.query.filter().count()
        print(count_in_clinic2)  # 0 in 'clinic2'

    # Back in the outer context.
    count_in_clinic1 = Patient.query.filter().count()
    print(count_in_clinic1)  # 1 in 'clinic1'
About foreign keys referencing the default bind
You have to specify the schema, both on the referenced table and in the foreign key target.
Limitations:
- MySQL:
- Binds must be in the same MySQL instance. Otherwise, the reference has to be a plain column without a foreign key constraint.
- The foreign object in the default bind must already be committed.
Otherwise, when inserting an object that references it, you will get this lock error:
MySQLdb._exceptions.OperationalError: (1205, 'Lock wait timeout exceeded; try restarting transaction')
- SQLite: Foreign keys across databases are not enforced.
Usage:
# app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql://user:pass@localhost/main'
class PatientType(db.Model):
    # No __bind_key__ here: this is the table that Patient references by a
    # schema-qualified name.
    __tablename__ = "patient_types"
    # Add this, based on database name
    __table_args__ = dict(schema="main")

    id = Column(Integer, primary_key=True)
    # ...
class Patient(db.Model):
    __bind_key__ = "__all__"
    __tablename__ = "patients"

    id = Column(Integer, primary_key=True)
    # ...
    # The foreign key target carries the schema of the referenced table.
    # patient_type_id = Column(Integer, ForeignKey("patient_types.id")) # Replace this
    patient_type_id = Column(Integer, ForeignKey("main.patient_types.id")) # with this
    patient_type = relationship("PatientType")
Test case:
# Reuse an existing PatientType, or create and commit one first.
patient_type = PatientType.query.first()
if not patient_type:
    patient_type = PatientType()
    db.session.add(patient_type)
    db.session.commit()  # Commit to reference from other binds

with db.context(bind='clinic1'):
    db.session.add(Patient(patient_type=patient_type))
    db.session.flush()  # Flush in 'clinic1'

    # A nested context switches the bind until its block ends.
    with db.context(bind='clinic2'):
        count_in_clinic2 = Patient.query.filter().count()
        print(count_in_clinic2)  # 0 in 'clinic2'

    # Back in the outer context.
    count_in_clinic1 = Patient.query.filter().count()
    print(count_in_clinic1)  # 1 in 'clinic1'