diff --git a/src/postgres/query/constraints/mod.rs b/src/postgres/query/constraints/mod.rs index 775f4326..c609b604 100644 --- a/src/postgres/query/constraints/mod.rs +++ b/src/postgres/query/constraints/mod.rs @@ -147,16 +147,51 @@ impl SchemaQueryBuilder { (Schema::ConstraintColumnUsage, Kcuf::TableName), (Schema::ConstraintColumnUsage, Kcuf::ColumnName), ]) + .columns(vec![ + // Extract the ordinal position of the referenced primary keys + (Schema::KeyColumnUsage, Kcuf::OrdinalPosition), + ]) .from((Schema::Schema, Schema::ReferentialConstraints)) .left_join( (Schema::Schema, Schema::ConstraintColumnUsage), Expr::col((Schema::ReferentialConstraints, RefC::ConstraintName)) .equals((Schema::ConstraintColumnUsage, Kcuf::ConstraintName)), ) + .left_join( + // Join the key_column_usage rows for the referenced primary keys + (Schema::Schema, Schema::KeyColumnUsage), + Condition::all() + .add( + Expr::col((Schema::ConstraintColumnUsage, Kcuf::ColumnName)) + .equals((Schema::KeyColumnUsage, Kcuf::ColumnName)), + ) + .add( + Expr::col(( + Schema::ReferentialConstraints, + RefC::UniqueConstraintName, + )) + .equals((Schema::KeyColumnUsage, Kcuf::ConstraintName)), + ) + .add( + Expr::col(( + Schema::ReferentialConstraints, + RefC::UniqueConstraintSchema, + )) + .equals((Schema::KeyColumnUsage, Kcuf::ConstraintSchema)), + ), + ) .take(), rcsq.clone(), - Expr::col((Schema::TableConstraints, Tcf::ConstraintName)) - .equals((rcsq.clone(), RefC::ConstraintName)), + Condition::all() + .add( + Expr::col((Schema::TableConstraints, Tcf::ConstraintName)) + .equals((rcsq.clone(), RefC::ConstraintName)), + ) + .add( + // Only join when the referenced primary key position matches position_in_unique_constraint for the foreign key column + Expr::col((Schema::KeyColumnUsage, Kcuf::PositionInUniqueConstraint)) + .equals((rcsq.clone(), Kcuf::OrdinalPosition)), + ), ) .and_where( Expr::col((Schema::TableConstraints, Tcf::TableSchema)).eq(schema.to_string()), diff --git a/tests/live/postgres/src/main.rs b/tests/live/postgres/src/main.rs index 721c8e8c..90c89ed0 100644 --- a/tests/live/postgres/src/main.rs +++ b/tests/live/postgres/src/main.rs @@ -52,6 +52,8 @@ async fn main() { create_cakes_bakers_table(), create_lineitem_table(), create_collection_table(), + create_parent_table(), + create_child_table(), ]; for tbl_create_stmt in tbl_create_stmts.iter() { @@ -374,3 +376,51 @@ fn create_collection_table() -> TableCreateStatement { ) .to_owned() } + +fn create_parent_table() -> TableCreateStatement { + Table::create() + .table(Alias::new("parent")) + .col(ColumnDef::new(Alias::new("id1")).integer().not_null()) + .col(ColumnDef::new(Alias::new("id2")).integer().not_null()) + .primary_key( + Index::create() + .primary() + .name("parent_pkey") + .col(Alias::new("id1")) + .col(Alias::new("id2")), + ) + .to_owned() +} + +fn create_child_table() -> TableCreateStatement { + Table::create() + .table(Alias::new("child")) + .col( + ColumnDef::new(Alias::new("id")) + .integer() + .not_null() + .auto_increment(), + ) + .col( + ColumnDef::new(Alias::new("parent_id1")) + .integer() + .not_null(), + ) + .col( + ColumnDef::new(Alias::new("parent_id2")) + .integer() + .not_null(), + ) + .foreign_key( + ForeignKey::create() + .name("FK_child_parent") + .from( + Alias::new("child"), + (Alias::new("parent_id1"), Alias::new("parent_id2")), + ) + .to(Alias::new("parent"), (Alias::new("id1"), Alias::new("id2"))) + .on_delete(ForeignKeyAction::Cascade) + .on_update(ForeignKeyAction::Cascade), + ) + .to_owned() +}