Skip to content

Commit

Permalink
Subpart11 for async drop (major5) - shims codegen
Browse files Browse the repository at this point in the history
  • Loading branch information
azhogin committed Aug 29, 2024
1 parent b52aadb commit 4d2435f
Show file tree
Hide file tree
Showing 2 changed files with 568 additions and 12 deletions.
260 changes: 248 additions & 12 deletions compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::visit::{MutVisitor, PlaceContext};
use rustc_middle::mir::*;
use rustc_middle::query::Providers;
use rustc_middle::ty::{
self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, GenericArgs, Ty, TyCtxt,
};
use rustc_middle::{bug, span_bug};
use rustc_mir_dataflow::elaborate_drops::{self, DropElaborator, DropFlagMode, DropStyle};
use rustc_span::source_map::Spanned;
use rustc_span::source_map::{dummy_spanned, Spanned};
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
use rustc_target::spec::abi::Abi;
Expand All @@ -23,10 +24,46 @@ use crate::{
instsimplify, mentioned_items, pass_manager as pm, remove_noop_landing_pads, simplify,
};

mod async_destructor_ctor;

pub fn provide(providers: &mut Providers) {
providers.mir_shims = make_shim;
}

// Replace Pin<&mut ImplCoroutine> accesses (_1.0) into Pin<&mut ProxyCoroutine> acceses
struct FixProxyFutureDropVisitor<'tcx> {
tcx: TyCtxt<'tcx>,
replace_to: Local,
}

impl<'tcx> MutVisitor<'tcx> for FixProxyFutureDropVisitor<'tcx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn visit_place(
&mut self,
place: &mut Place<'tcx>,
_context: PlaceContext,
_location: Location,
) {
if place.local == Local::from_u32(1) {
if place.projection.len() == 1 {
assert!(matches!(
place.projection.first(),
Some(ProjectionElem::Field(FieldIdx::ZERO, _))
));
*place = Place::from(self.replace_to);
} else if place.projection.len() == 2 {
assert!(matches!(place.projection[0], ProjectionElem::Field(FieldIdx::ZERO, _)));
assert!(matches!(place.projection[1], ProjectionElem::Deref));
*place =
Place::from(self.replace_to).project_deeper(&[ProjectionElem::Deref], self.tcx);
}
}
}
}

fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceKind<'tcx>) -> Body<'tcx> {
debug!("make_shim({:?})", instance);

Expand Down Expand Up @@ -70,7 +107,6 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceKind<'tcx>) -> Body<

build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
}

ty::InstanceKind::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
receiver_by_ref,
Expand All @@ -81,8 +117,6 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceKind<'tcx>) -> Body<
}

ty::InstanceKind::DropGlue(def_id, ty) => {
// FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
// of this function. Is this intentional?
if let Some(ty::Coroutine(coroutine_def_id, args)) = ty.map(Ty::kind) {
let coroutine_body = tcx.optimized_mir(*coroutine_def_id);

Expand Down Expand Up @@ -119,23 +153,225 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceKind<'tcx>) -> Body<
],
Some(MirPhase::Runtime(RuntimePhase::Optimized)),
);

return body;
}

build_drop_shim(tcx, def_id, ty)
}
ty::InstanceKind::ThreadLocalShim(..) => build_thread_local_shim(tcx, instance),
ty::InstanceKind::CloneShim(def_id, ty) => build_clone_shim(tcx, def_id, ty),
ty::InstanceKind::FnPtrAddrShim(def_id, ty) => build_fn_ptr_addr_shim(tcx, def_id, ty),
ty::InstanceKind::FutureDropPollShim(_def_id, _proxy_ty, _impl_ty) => {
todo!()
ty::InstanceKind::FutureDropPollShim(def_id, proxy_ty, impl_ty) => {
let ty::Coroutine(coroutine_def_id, impl_args) = impl_ty.kind() else {
bug!("FutureDropPollShim not for coroutine impl type: ({:?})", instance);
};

let span = tcx.def_span(def_id);
let source_info = SourceInfo::outermost(span);

let pin_proxy_layout_local = Local::new(1);
let cor_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, impl_ty);
let proxy_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, proxy_ty);
// taking _1.0.0 (impl from Pin, impl from proxy)
let proxy_ref_place = Place::from(pin_proxy_layout_local)
.project_deeper(&[PlaceElem::Field(FieldIdx::ZERO, proxy_ref)], tcx);
let impl_ref_place = |proxy_ref_local: Local| {
Place::from(proxy_ref_local).project_deeper(
&[
PlaceElem::Deref,
PlaceElem::Downcast(None, VariantIdx::ZERO),
PlaceElem::Field(FieldIdx::ZERO, cor_ref),
],
tcx,
)
};

if tcx.is_templated_coroutine(*coroutine_def_id) {
// ret_ty = `Poll<()>`
let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, None));
let ret_ty = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
// env_ty = `Pin<&mut proxy_ty>`
let pin_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Pin, None));
let env_ty = Ty::new_adt(tcx, pin_adt_ref, tcx.mk_args(&[proxy_ref.into()]));
// sig = `fn (Pin<&mut proxy_ty>, &mut Context) -> Poll<()>`
let sig = tcx.mk_fn_sig(
[env_ty, Ty::new_task_context(tcx)],
ret_ty,
false,
hir::Safety::Safe,
rustc_target::spec::abi::Abi::Rust,
);
let mut locals = local_decls_for_sig(&sig, span);
let mut blocks = IndexVec::with_capacity(3);

let proxy_ref_local = locals.push(LocalDecl::new(proxy_ref, span));
let cor_ref_local = locals.push(LocalDecl::new(cor_ref, span));
let cor_ref_place = Place::from(cor_ref_local);

let call_bb = BasicBlock::new(1);
let return_bb = BasicBlock::new(2);

let assign1 = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
Place::from(proxy_ref_local),
Rvalue::CopyForDeref(proxy_ref_place),
))),
};
let assign2 = Statement {
source_info,
kind: StatementKind::Assign(Box::new((
cor_ref_place,
Rvalue::CopyForDeref(impl_ref_place(proxy_ref_local)),
))),
};

// cor_pin_ty = `Pin<&mut cor_ref>`
let cor_pin_ty = Ty::new_adt(tcx, pin_adt_ref, tcx.mk_args(&[cor_ref.into()]));
let cor_pin_place = Place::from(locals.push(LocalDecl::new(cor_pin_ty, span)));

let pin_fn = tcx.require_lang_item(LangItem::PinNewUnchecked, Some(span));
// call Pin<FutTy>::new_unchecked(&mut impl_cor)
blocks.push(BasicBlockData {
statements: vec![assign1, assign2],
terminator: Some(Terminator {
source_info,
kind: TerminatorKind::Call {
func: Operand::function_handle(tcx, pin_fn, [cor_ref.into()], span),
args: [dummy_spanned(Operand::Move(cor_ref_place))].into(),
destination: cor_pin_place,
target: Some(call_bb),
unwind: UnwindAction::Continue,
call_source: CallSource::Misc,
fn_span: span,
},
}),
is_cleanup: false,
});
// When dropping async drop coroutine, we continue its execution:
// we call impl::poll (impl_layout, ctx)
let poll_fn = tcx.require_lang_item(LangItem::FuturePoll, None);
let resume_ctx = Place::from(Local::new(2));
blocks.push(BasicBlockData {
statements: vec![],
terminator: Some(Terminator {
source_info,
kind: TerminatorKind::Call {
func: Operand::function_handle(tcx, poll_fn, [impl_ty.into()], span),
args: [
dummy_spanned(Operand::Move(cor_pin_place)),
dummy_spanned(Operand::Move(resume_ctx)),
]
.into(),
destination: Place::return_place(),
target: Some(return_bb),
unwind: UnwindAction::Continue,
call_source: CallSource::Misc,
fn_span: span,
},
}),
is_cleanup: false,
});
blocks.push(BasicBlockData {
statements: vec![],
terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
is_cleanup: false,
});

let source = MirSource::from_instance(instance);
let mut body = new_body(source, blocks, locals, sig.inputs().len(), span);
pm::run_passes(
tcx,
&mut body,
&[
&mentioned_items::MentionedItems,
&abort_unwinding_calls::AbortUnwindingCalls,
&add_call_guards::CriticalCallEdges,
],
Some(MirPhase::Runtime(RuntimePhase::Optimized)),
);
return body;
}
// future drop poll for async drop must be resolved to standart poll (AsyncDropGlue)
assert!(!tcx.is_templated_coroutine(*coroutine_def_id));

// converting `(_1: Pin<&mut CorLayout>, _2: &mut Context<'_>) -> Poll<()>`
// into `(_1: Pin<&mut ProxyLayout>, _2: &mut Context<'_>) -> Poll<()>`
// let mut _x: &mut CorLayout = &*_1.0.0;
// Replace old _1.0 accesses into _x accesses;
let body = tcx.optimized_mir(*coroutine_def_id).future_drop_poll().unwrap();
let mut body: Body<'tcx> = EarlyBinder::bind(body.clone()).instantiate(tcx, impl_args);
body.source.instance = instance;
body.var_debug_info.clear();
let pin_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Pin, Some(span)));
let args = tcx.mk_args(&[proxy_ref.into()]);
let pin_proxy_ref = Ty::new_adt(tcx, pin_adt_ref, args);

let proxy_ref_local = body.local_decls.push(LocalDecl::new(proxy_ref, span));
let cor_ref_local = body.local_decls.push(LocalDecl::new(cor_ref, span));
FixProxyFutureDropVisitor { tcx, replace_to: cor_ref_local }.visit_body(&mut body);
// Now changing first arg from Pin<&mut ImplCoroutine> to Pin<&mut ProxyCoroutine>
body.local_decls[pin_proxy_layout_local] = LocalDecl::new(pin_proxy_ref, span);

{
let bb: &mut BasicBlockData<'tcx> = &mut body.basic_blocks_mut()[START_BLOCK];
// _tmp = _1.0 : Pin<&ProxyLayout> ==> &ProxyLayout
bb.statements.insert(
0,
Statement {
source_info,
kind: StatementKind::Assign(Box::new((
Place::from(proxy_ref_local),
Rvalue::CopyForDeref(proxy_ref_place),
))),
},
);
bb.statements.insert(
1,
Statement {
source_info,
kind: StatementKind::Assign(Box::new((
Place::from(cor_ref_local),
Rvalue::CopyForDeref(impl_ref_place(proxy_ref_local)),
))),
},
);
}

pm::run_passes(
tcx,
&mut body,
&[
&mentioned_items::MentionedItems,
&abort_unwinding_calls::AbortUnwindingCalls,
&add_call_guards::CriticalCallEdges,
],
Some(MirPhase::Runtime(RuntimePhase::Optimized)),
);
debug!("make_shim({:?}) = {:?}", instance, body);
return body;
}
ty::InstanceKind::AsyncDropGlue(_def_id, _ty) => {
todo!()
ty::InstanceKind::AsyncDropGlue(def_id, ty) => {
let mut body = async_destructor_ctor::build_async_drop_shim(tcx, def_id, ty);

pm::run_passes(
tcx,
&mut body,
&[
&mentioned_items::MentionedItems,
&simplify::SimplifyCfg::MakeShim,
&crate::reveal_all::RevealAll,
&crate::coroutine::StateTransform,
],
Some(MirPhase::Runtime(RuntimePhase::PostCleanup)),
);
debug!("make_shim({:?}) = {:?}", instance, body);
return body;
}
ty::InstanceKind::AsyncDropGlueCtorShim(_def_id, _ty) => {
bug!("AsyncDropGlueCtorShim in re-working ({:?})", instance)

ty::InstanceKind::AsyncDropGlueCtorShim(def_id, ty) => {
let body = async_destructor_ctor::build_async_destructor_ctor_shim(tcx, def_id, ty);
debug!("make_shim({:?}) = {:?}", instance, body);
return body;
}
ty::InstanceKind::Virtual(..) => {
bug!("InstanceKind::Virtual ({:?}) is for direct calls only", instance)
Expand Down
Loading

0 comments on commit 4d2435f

Please sign in to comment.