Skip to content

Commit

Permalink
Support Partial Rope
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos authored and kali committed Dec 17, 2024
1 parent 0100dc0 commit 640497b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 10 deletions.
26 changes: 19 additions & 7 deletions metal/src/rewrite_rules/apply_rope.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::rewrite_rules::BasicRotateHalf;
use crate::rewrite_rules::{previous_nodes, single_prev_node_as};
use crate::rewrite_rules::{previous_node, previous_nodes, single_prev_node_as, BasicRotateHalf};
use crate::rule_ensure;
use tract_core::internal::*;
use tract_core::ops::binary::BinMiniOp;
Expand Down Expand Up @@ -75,11 +74,24 @@ pub fn as_apply_rope_rule(
return Ok(None);
};

let apply_rope_in = rotate_half.inputs[0];
rule_ensure!(cos_mul.inputs.contains(&apply_rope_in));

let cos =
if cos_mul.inputs[0] == apply_rope_in { cos_mul.inputs[1] } else { cos_mul.inputs[0] };
// If cos_mul and rotate_half don't share the same input outlet, check whether one of
// cos_mul's previous nodes is the same as the node feeding rotate_half.
let (apply_rope_in, cos) = if !cos_mul.inputs.contains(&rotate_half.inputs[0]) {
let Some(rotate_half_prev) = previous_node(model, rotate_half) else { return Ok(None) };
let Some((cos_common_input_idx, _)) = previous_nodes(model, cos_mul)
.iter()
.enumerate()
.find(|(_, n)| n.same_as(rotate_half_prev))
else {
return Ok(None);
};
(rotate_half.inputs[0], cos_mul.inputs[1 - cos_common_input_idx])
} else {
let apply_rope_in = rotate_half.inputs[0];
let cos =
if cos_mul.inputs[0] == apply_rope_in { cos_mul.inputs[1] } else { cos_mul.inputs[0] };
(apply_rope_in, cos)
};

let sin = sin_mul.inputs[1 - rotate_half_in_idx];

Expand Down
20 changes: 17 additions & 3 deletions metal/src/rewrite_rules/rotate_half.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ pub fn as_rotate_half_rule(
let Some(pos_half_op) = pos_half.op_as::<Slice>() else { return Ok(None) };

rule_ensure!(pos_half_op.axis == op.axis);
rule_ensure!(pos_half_op.start == 0.into());
rule_ensure!(pos_half_op.end == neg_half_slice_op.start);
rule_ensure!(neg_half_slice_op.end == out_fact.shape[op.axis].clone());

Expand All @@ -84,9 +83,24 @@ pub fn as_rotate_half_rule(
let Some(concatenated_last_dim) = out_fact.shape[op.axis].as_i64() else { return Ok(None) };
rule_ensure!(pos_half_slice_end * 2 == concatenated_last_dim);

let in_fact = model.node_input_facts(neg_half_slice.id)?[0];

let mut patch = TypedModelPatch::default();
let input = patch.taps(model, &neg_half_slice.inputs)?;
let out = patch.wire_node(format!("{node_name}.rotate_half"), BasicRotateHalf, &input)?;
let mut inputs = patch.taps(model, &neg_half_slice.inputs)?;

if pos_half_op.start != 0.into() || neg_half_slice_op.end != in_fact.shape[op.axis] {
inputs = patch.wire_node(
format!("{node_name}.rotate_half.slice"),
Slice {
start: pos_half_op.start.clone(),
end: neg_half_slice_op.end.clone(),
axis: op.axis,
},
&inputs,
)?;
}

let out = patch.wire_node(format!("{node_name}.rotate_half"), BasicRotateHalf, &inputs)?;
patch.shunt_outside(model, node.id.into(), out[0])?;

Ok(Some(patch))
Expand Down

0 comments on commit 640497b

Please sign in to comment.