OK, what you're trying to do is a little hard, and yes, `__declare_last__` / `__declare_first__` are useful here: I just noticed that you need to inspect the primary key of the local class, not the remote one, so the local mapping has to be set up first. Here is a demo based on `__declare_first__` that shows the basic idea; it works either with a `ForeignKeyConstraint` or with a `primaryjoin` condition:
from __future__ import annotations

from typing import Type

from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import configure_mappers
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
# Declarative base for every mapped class below; their tables are collected on
# Base.metadata, which create_all() uses further down to emit the DDL.
Base = declarative_base()
class TransitionBase(Base):
    """Abstract declarative parent for state-transition classes.

    Being ``__abstract__``, it gets no table of its own; concrete subclasses
    inherit an integer primary key ``id`` and a non-nullable, indexed
    ``state`` string column.
    """

    __abstract__ = True

    @declared_attr
    def id(cls):
        # declared_attr builds the Column lazily, once per concrete subclass,
        # so subclasses do not share one primary-key Column object.
        pk_column = Column(Integer, primary_key=True)
        return pk_column

    state = Column(String, index=True, nullable=False)
class HasStateMachineMixin:
    """Mixin that gives a mapped class a ``transitions`` relationship.

    Each concrete class names its transition class via :meth:`get_state_class`.
    From the ``__declare_first__`` hook (i.e. before mappers are configured),
    this mixin adds the following to that transition class:

    * one new column per primary-key column of this (local) class,
    * a composite foreign key from those columns back to the local primary key,
    * a ``transitions`` relationship on this class, ordered by descending
      transition ``id``.
    """

    @staticmethod
    def get_state_class() -> Type[TransitionBase]:
        """Return the TransitionBase subclass that stores this class's transitions.

        Raises:
            NotImplementedError: unless overridden by the mapped class.
        """
        raise NotImplementedError()

    @classmethod
    def __declare_first__(cls) -> None:
        """Wire ``cls`` to its transition class before mapper configuration.

        The local class's primary key must be inspected (not the remote one's),
        which is why this runs from ``__declare_first__``: the local mapping
        already exists, but relationships are not configured yet.

        NOTE(review): this mutates the transition table every time it runs; if
        the hook can fire more than once (e.g. repeated configure_mappers()
        calls after new mappers are added), a guard may be needed -- confirm.
        """
        dest = cls.get_state_class()
        src = inspect(cls)  # the Mapper for cls; exposes its primary key

        # One new column on the transition table per local primary-key column,
        # named "<local table>_<pk column>" so that several owner classes can
        # share a single transition table without column-name clashes.
        dest_cols = [
            Column("%s_%s" % (src.local_table.name, pk.name), pk.type)
            for pk in src.primary_key
        ]

        # Make a ForeignKeyConstraint. If you wanted just a primaryjoin
        # instead, you could create it as:
        # primaryjoin=and_(
        #     *[(a == foreign(b)) for a, b in zip(src.primary_key, dest_cols)])
        dest.__table__.append_constraint(
            ForeignKeyConstraint(dest_cols, src.primary_key)
        )

        # These two steps rely on DeclarativeMeta accepting new columns and
        # attributes on an already-declared class: assigning a Column adds it
        # to the table/mapper, assigning a relationship() maps it.
        for dc in dest_cols:
            setattr(dest, dc.name, dc)
        cls.transitions = relationship(dest, order_by=dest.id.desc())
class Transition(TransitionBase):
    """Concrete transition class mapped to the ``transitions`` table.

    Obj and ThreePrimaryKeys both return this class from get_state_class(),
    so their transition rows share this one table; the owner foreign-key
    columns are not declared here but attached by HasStateMachineMixin.
    """

    __tablename__ = "transitions"

    # NOTE(review): same integer primary key that TransitionBase.id already
    # supplies via declared_attr; the explicit override is kept as-is.
    id = Column(Integer, primary_key=True)
class Obj(HasStateMachineMixin, Base):
    """Mapped class keyed by a single string ``id``; transitions use Transition."""

    __tablename__ = "obj"

    id = Column(String, primary_key=True)

    @staticmethod
    def get_state_class() -> Type[TransitionBase]:
        # Obj's state history is stored as Transition rows.
        state_cls = Transition
        return state_cls
class ThreePrimaryKeys(HasStateMachineMixin, Base):
    """Mapped class with a composite (a, b, c) string primary key.

    Demonstrates the mixin with a multi-column key; transitions use Transition.
    """

    __tablename__ = "three_pks"

    a = Column(String, primary_key=True)
    b = Column(String, primary_key=True)
    c = Column(String, primary_key=True)

    @staticmethod
    def get_state_class() -> Type[TransitionBase]:
        # Shares the Transition class (and its table) with Obj.
        state_cls = Transition
        return state_cls
# Since the mappers are going to add new columns, mapper configuration must be
# triggered before the DDL is rendered; this ensures the __declare_first__
# hooks above run.
configure_mappers()

# "sqlite://" with no path is an in-memory SQLite database; echo=True logs SQL.
e = create_engine("sqlite://", echo=True)
Base.metadata.create_all(e)

s = Session(e)
try:
    s.add(
        ThreePrimaryKeys(
            a="a",
            b="b",
            c="c",
            transitions=[
                Transition(state="t1"),
                Transition(state="t2"),
                Transition(state="t3"),
            ],
        )
    )
    s.add(Obj(id="one", transitions=[Transition(state="tt1")]))
    s.commit()
finally:
    # Always release the session's connection, even if flush/commit fails.
    s.close()