I need to delete the association table rows for many to many relationships when I delete one, but the default behavior (to remove those rows) does not seem to work in my case.
I have multiple levels of many to many relationships, as you can see in the example I'll provide below and when I delete a "parent" afterwards I try to clean up any children left behind that have no other parents. However, these children are in many to many relationships with other children and that's when the ORM fails to attempt to remove those children from their related association tables (at least, in a way that I expect).
The desired effect, of course, is that the rows in test_chain_var_region that reference the deleted chains removed. I've tried several strategies to do this but with no change in this behavior.
import pytest
from sqlalchemy import (
Table,
Column,
Integer,
String,
ForeignKey,
create_engine,
)
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
from sqlalchemy.orm import relationship, Session
from sqlalchemy.util import OrderedSet
Base: DeclarativeMeta = declarative_base()
engine = create_engine(
"postgresql://postgres:postgres@localhost:5432/espresso", echo=True
)
test_chain_const_region = Table(
"test_chain_const_region",
Base.metadata,
Column("chain_id", Integer, ForeignKey("
test_chain.id")),
Column("const_region_id", Integer, ForeignKey("
test_const_region.id")),
)
test_chain_var_region = Table(
"test_chain_var_region",
Base.metadata,
Column("chain_id", Integer, ForeignKey("
test_chain.id")),
Column("var_region_id", Integer, ForeignKey("
test_var_region.id")),
)
test_molecule_chain = Table(
"test_molecule_chain",
Base.metadata,
Column("molecule_id", Integer, ForeignKey("
test_molecule.id")),
Column("chain_id", Integer, ForeignKey("
test_chain.id")),
)
test_mol_sequence_feat_mol_sequence = Table(
"test_mol_sequence_feat_mol_sequence",
Base.metadata,
Column("mol_sequence_feat_id", Integer, ForeignKey("
test_mol_sequence_feat.id")),
Column("mol_sequence_id", Integer, ForeignKey("
test_mol_sequence.id")),
)
class TestMolecule(Base):
__tablename__ = "test_molecule"
id = Column(Integer, primary_key=True)
label = Column(String)
chains = relationship(
"TestChain",
secondary=test_molecule_chain,
collection_class=OrderedSet,
back_populates="molecules",
)
class TestMolSequence(Base):
__tablename__ = "test_mol_sequence"
id = Column(Integer, primary_key=True)
content = Column(String, nullable=False, unique=True)
parent_features = relationship(
"TestMolSequenceFeat",
secondary=test_mol_sequence_feat_mol_sequence,
collection_class=OrderedSet,
back_populates="feature_sequences",
single_parent=True,
)
chains = relationship(
"TestChain", back_populates="mol_sequence", collection_class=OrderedSet
)
class TestMolSequenceFeat(Base):
__tablename__ = "test_mol_sequence_feat"
id = Column(Integer, primary_key=True)
molecule_sequence_id = Column(
Integer, ForeignKey("
test_mol_sequence.id", ondelete="CASCADE"),
)
molecule_sequence = relationship("TestMolSequence",)
start = Column(Integer)
stop = Column(Integer)
feature_sequences = relationship(
"TestMolSequence",
secondary=test_mol_sequence_feat_mol_sequence,
collection_class=OrderedSet,
back_populates="parent_features",
# single_parent=True,
)
class TestChain(Base):
__tablename__ = "test_chain"
id = Column(Integer, primary_key=True)
label = Column(String)
chain_type = Column(String)
mol_sequence_id = Column(Integer, ForeignKey("
test_mol_sequence.id"))
mol_sequence = relationship("TestMolSequence", back_populates="chains")
molecules = relationship(
"TestMolecule",
secondary=test_molecule_chain,
collection_class=OrderedSet,
back_populates="chains",
)
var_regions = relationship(
"TestVarRegion",
secondary=test_chain_var_region,
collection_class=OrderedSet,
back_populates="chains",
)
const_regions = relationship(
"TestConstRegion",
secondary=test_chain_const_region,
collection_class=OrderedSet,
back_populates="chains",
)
class TestVarRegion(Base):
__tablename__ = "test_var_region"
id = Column(Integer, primary_key=True)
molecule_sequence_id = Column(
Integer, ForeignKey("
test_mol_sequence.id", ondelete="CASCADE"),
)
description = Column(String)
additional_information = Column(String)
label = Column("label", String, nullable=True, unique=False)
molecule_sequence = relationship("TestMolSequence")
chains = relationship(
"TestChain",
secondary=test_chain_var_region,
collection_class=OrderedSet,
back_populates="var_regions",
passive_deletes=True,
)
class TestConstRegion(Base):
__tablename__ = "test_const_region"
id = Column(Integer, primary_key=True)
molecule_sequence_id = Column(
Integer, ForeignKey("
test_mol_sequence.id", ondelete="CASCADE"),
)
description = Column(String)
additional_information = Column(String)
label = Column("label", String, nullable=True, unique=False)
molecule_sequence = relationship("TestMolSequence")
chains = relationship(
"TestChain",
secondary=test_chain_const_region,
collection_class=OrderedSet,
back_populates="const_regions",
passive_deletes=True,
)
class TestManyToMany:
@pytest.fixture
def engine(self):
return create_engine(
"postgresql://postgres:postgres@localhost:5432/espresso", echo=True
)
@pytest.fixture
def session(self):
"""Returns an sqlalchemy session, and after the test tears down everything properly."""
connection = engine.connect()
# begin the nested transaction
transaction = connection.begin()
# use the connection with the already started transaction
session = Session(bind=connection)
yield session
session.close()
# roll back the broader transaction
transaction.commit()
# put back the connection to the connection pool
connection.close()
@pytest.mark.create_m2m_models
def test_create_m2m_models(self, engine):
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
@pytest.mark.seed_m2m_data
def test_m2m_seeding_data(self, engine, session):
molecule1 = TestMolecule(label="molecule1")
molecule2 = TestMolecule(label="molecule2")
molecule3 = TestMolecule(label="molecule3")
molecule4 = TestMolecule(label="molecule4")
light_chain_1_sequence = TestMolSequence(content="taglconst1VAR1")
heavy_chain_1_sequence = TestMolSequence(content="tagheavyconstant1VAR2")
heavy_chain_2_sequence = TestMolSequence(content="tagheavyconstant2VAR2")
light_chain_2_sequence = TestMolSequence(content="taglconst1VAR3")
heavy_chain_1 = TestChain(
chain_type="heavy", mol_sequence=light_chain_1_sequence
)
light_chain_1 = TestChain(
chain_type="light", mol_sequence=heavy_chain_1_sequence
)
light_chain_2 = TestChain(
chain_type="light", mol_sequence=light_chain_2_sequence
)
heavy_chain_2 = TestChain(
chain_type="heavy", mol_sequence=heavy_chain_2_sequence
)
molecule1.chains.add(heavy_chain_1)
molecule1.chains.add(light_chain_1)
molecule2.chains.add(heavy_chain_2)
molecule2.chains.add(light_chain_2)
molecule3.chains.add(heavy_chain_1)
molecule3.chains.add(light_chain_2)
molecule4.chains.add(heavy_chain_2)
molecule4.chains.add(light_chain_1)
tag_sequence = TestMolSequence(content="tag")
light_constant_region_seq = TestMolSequence(content="lconst1")
heavy_constant_region_1_seq = TestMolSequence(content="heavyconstant1")
heavy_constant_region_2_seq = TestMolSequence(content="heavyconstant2")
vr1_seq = TestMolSequence(content="VAR1")
vr2_seq = TestMolSequence(content="VAR2")
vr3_seq = TestMolSequence(content="VAR3")
# lc2_const_region_seq = TestMolSequence(content="lconst")
lc1_tag_feature = TestMolSequenceFeat(
start=0, stop=3, molecule_sequence=light_chain_1_sequence
)
lc1_const_region_feature = TestMolSequenceFeat(
start=3, stop=10, molecule_sequence=light_chain_1_sequence
)
lc1_var_region_feature = TestMolSequenceFeat(
start=10, stop=14, molecule_sequence=light_chain_1_sequence
)
hc1_tag_feature = TestMolSequenceFeat(
start=0, stop=3, molecule_sequence=heavy_chain_1_sequence
)
hc1_const_region_feature = TestMolSequenceFeat(
start=3, stop=17, molecule_sequence=heavy_chain_1_sequence
)
hc1_var_region_feature = TestMolSequenceFeat(
start=17, stop=21, molecule_sequence=heavy_chain_1_sequence
)
hc2_tag_feature = TestMolSequenceFeat(
start=0, stop=3, molecule_sequence=heavy_chain_2_sequence
)
hc2_const_region_feature = TestMolSequenceFeat(
start=3, stop=17, molecule_sequence=heavy_chain_2_sequence
)
hc2_var_region_feature = TestMolSequenceFeat(
start=17, stop=21, molecule_sequence=heavy_chain_2_sequence
)
lc2_tag_feature = TestMolSequenceFeat(
start=0, stop=3, molecule_sequence=light_chain_2_sequence
)
lc2_const_region_feature = TestMolSequenceFeat(
start=3, stop=10, molecule_sequence=light_chain_2_sequence
)
lc2_var_region_feature = TestMolSequenceFeat(
start=10, stop=14, molecule_sequence=light_chain_2_sequence
)
var_region1 = TestVarRegion(molecule_sequence=vr1_seq)
var_region2 = TestVarRegion(molecule_sequence=vr2_seq)
var_region3 = TestVarRegion(molecule_sequence=vr3_seq)
const_region1 = TestConstRegion(molecule_sequence=light_constant_region_seq)
const_region2 = TestConstRegion(molecule_sequence=heavy_constant_region_1_seq)
const_region3 = TestConstRegion(molecule_sequence=heavy_constant_region_2_seq)
light_chain_1.var_regions.add(var_region1)
heavy_chain_1.var_regions.add(var_region2)
heavy_chain_2.var_regions.add(var_region2)
light_chain_2.var_regions.add(var_region3)
light_chain_1.const_regions.add(const_region1)
light_chain_2.const_regions.add(const_region1)
heavy_chain_1.const_regions.add(const_region2)
heavy_chain_2.const_regions.add(const_region3)
lc1_tag_feature.feature_sequences.add(tag_sequence)
lc1_var_region_feature.feature_sequences.add(vr1_seq)
lc1_const_region_feature.feature_sequences.add(light_constant_region_seq)
hc1_tag_feature.feature_sequences.add(tag_sequence)
hc1_var_region_feature.feature_sequences.add(vr2_seq)
hc1_const_region_feature.feature_sequences.add(heavy_constant_region_1_seq)
lc2_tag_feature.feature_sequences.add(tag_sequence)
lc2_var_region_feature.feature_sequences.add(vr3_seq)
lc2_const_region_feature.feature_sequences.add(light_constant_region_seq)
hc2_tag_feature.feature_sequences.add(tag_sequence)
hc2_var_region_feature.feature_sequences.add(vr2_seq)
hc2_const_region_feature.feature_sequences.add(heavy_constant_region_2_seq)
session.add_all(
[
heavy_chain_1,
light_chain_1,
light_chain_2,
heavy_chain_2,
molecule1,
molecule2,
molecule3,
molecule4,
# tag_sequence,
# lc1_tag_feature,
# lc1_const_region_feature,
# hc1_tag_feature,
# hc1_var_region_feature,
# lc1_tag_feature,
# lc2_const_region_feature,
]
)
session.commit()
@pytest.mark.delete_test_m2m_models
def test_create_m2m(self, session):
molecule = session.query(TestMolecule).filter_by(label="molecule1").one()
session.delete(molecule)
session.query(TestChain).filter(~TestChain.molecules.any()).delete(
synchronize_session="fetch"
)
session.expire_all()
orphan_chains = (
session.query(TestChain).filter(~TestChain.molecules.any()).all()
)
for chain in orphan_chains:
session.delete(chain)
orphan_vrs = (
session.query(TestVarRegion).filter(~TestVarRegion.chains.any()).all()
)
for orphan_vr in orphan_vrs:
session.delete(orphan_vr)
orphan_crs = (
session.query(TestConstRegion)
.filter(~TestConstRegion.chains.any())
.all()
)
for orphan_cr in orphan_crs:
session.delete(orphan_cr)
orphan_sequences = (
session.query(TestMolSequence).filter(~TestMolSequence.chains.any()).all()
)
orphan_sequence: TestMolSequence
for orphan_sequence in orphan_sequences:
session.delete(orphan_sequence)
session.commit()