Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solve ProjectionPredicate during normalization #830

Merged
merged 9 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/flux-driver/src/collector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use rustc_errors::ErrorGuaranteed;
use rustc_hir::{
self as hir,
def::DefKind,
def_id::{DefId, LocalDefId, CRATE_DEF_ID},
def_id::{LocalDefId, CRATE_DEF_ID},
EnumDef, ImplItemKind, Item, ItemKind, OwnerId, VariantData, CRATE_OWNER_ID,
};
use rustc_middle::ty::TyCtxt;
Expand Down
115 changes: 109 additions & 6 deletions crates/flux-middle/src/rty/projections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ use rustc_trait_selection::traits::SelectionContext;

use super::{
fold::{FallibleTypeFolder, TypeFoldable, TypeSuperFoldable},
AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, Expr, ExprKind,
GenericArg, ProjectionPredicate, RefineArgs, Region, SubsetTy, Ty, TyKind,
subst::{GenericsSubstDelegate, GenericsSubstFolder},
AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, EarlyBinder, Expr,
ExprKind, GenericArg, ProjectionPredicate, RefineArgs, Region, SubsetTy, Ty, TyKind,
};
use crate::{
global_env::GlobalEnv,
Expand Down Expand Up @@ -136,7 +137,65 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
Ok((true, ty))
}

fn confirm_candidate(&self, candidate: Candidate, obligation: &AliasTy) -> QueryResult<Ty> {
fn find_resolved_predicates(
&self,
subst: &mut TVarSubst,
preds: Vec<EarlyBinder<ProjectionPredicate>>,
) -> (Vec<ProjectionPredicate>, Vec<EarlyBinder<ProjectionPredicate>>) {
let mut resolved = vec![];
let mut unresolved = vec![];
for pred in preds {
let term = pred.clone().skip_binder().term;
let alias_ty = pred.clone().map(|p| p.projection_ty);
match subst.instantiate_partial(alias_ty) {
Some(projection_ty) => {
let pred = ProjectionPredicate { projection_ty, term };
resolved.push(pred);
}
None => unresolved.push(pred.clone()),
}
}
(resolved, unresolved)
}

// See issue-829*.rs for an example of what this function is for.
fn resolve_projection_predicates(
&mut self,
subst: &mut TVarSubst,
impl_def_id: DefId,
) -> QueryResult {
let mut projection_preds: Vec<_> = self
.genv
.predicates_of(impl_def_id)?
.skip_binder()
.predicates
.iter()
.filter_map(|pred| {
if let ClauseKind::Projection(pred) = pred.kind_skipping_binder() {
Some(EarlyBinder(pred.clone()))
} else {
None
}
})
.collect();

while !projection_preds.is_empty() {
let (resolved, unresolved) = self.find_resolved_predicates(subst, projection_preds);

if resolved.is_empty() {
break; // failed: there is some unresolved projection pred!
}
for p in resolved {
let obligation = &p.projection_ty;
let (_, ty) = self.normalize_projection_ty(obligation)?;
subst.tys(&p.term, &ty);
}
projection_preds = unresolved;
}
Ok(())
}

fn confirm_candidate(&mut self, candidate: Candidate, obligation: &AliasTy) -> QueryResult<Ty> {
match candidate {
Candidate::ParamEnv(pred) | Candidate::TraitDef(pred) => Ok(pred.term),
Candidate::UserDefinedImpl(impl_def_id) => {
Expand All @@ -145,9 +204,9 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
// and the id of a rust impl block
// impl<T, A: Allocator> Iterator for IntoIter<T, A>

// 1. Match the self type of the rust impl block and the flux self type of the obligation
// 1. MATCH the self type of the rust impl block and the flux self type of the obligation
// to infer a substitution
// IntoIter<{v. i32[v] | v > 0}, Global> against IntoIter<T, A>
// IntoIter<{v. i32[v] | v > 0}, Global> MATCH IntoIter<T, A>
// => {T -> {v. i32[v] | v > 0}, A -> Global}

let impl_trait_ref = self
Expand All @@ -162,9 +221,13 @@ impl<'genv, 'tcx, 'cx> Normalizer<'genv, 'tcx, 'cx> {
for (a, b) in iter::zip(&impl_trait_ref.args, &obligation.args) {
subst.generic_args(a, b);
}

// 2. Gather the ProjectionPredicates and solve them see issue-808.rs
self.resolve_projection_predicates(&mut subst, impl_def_id)?;

let args = subst.finish(self.tcx(), generics);

// 2. Get the associated type in the impl block and apply the substitution to it
// 3. Get the associated type in the impl block and apply the substitution to it
let assoc_type_id = self
.tcx()
.associated_items(impl_def_id)
Expand Down Expand Up @@ -316,11 +379,51 @@ struct TVarSubst {
args: Vec<Option<GenericArg>>,
}

impl GenericsSubstDelegate for &TVarSubst {
type Error = ();

fn ty_for_param(&mut self, param_ty: rustc_middle::ty::ParamTy) -> Result<Ty, Self::Error> {
match self.args.get(param_ty.index as usize) {
Some(Some(GenericArg::Ty(ty))) => Ok(ty.clone()),
Some(None) => Err(()),
arg => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
}
}

fn sort_for_param(
&mut self,
_param_ty: rustc_middle::ty::ParamTy,
) -> Result<super::Sort, Self::Error> {
tracked_span_bug!()
}

fn ctor_for_param(&mut self, _param_ty: rustc_middle::ty::ParamTy) -> super::SubsetTyCtor {
tracked_span_bug!()
}

fn region_for_param(&mut self, _ebr: rustc_middle::ty::EarlyParamRegion) -> Region {
tracked_span_bug!()
}

fn expr_for_param_const(&self, _param_const: rustc_middle::ty::ParamConst) -> Expr {
tracked_span_bug!()
}

fn const_for_param(&mut self, _param: &Const) -> Const {
tracked_span_bug!()
}
}

impl TVarSubst {
fn new(generics: &rustc_middle::ty::Generics) -> Self {
Self { args: vec![None; generics.count()] }
}

fn instantiate_partial<T: TypeFoldable>(&mut self, pred: EarlyBinder<T>) -> Option<T> {
let mut folder = GenericsSubstFolder::new(&*self, &[]);
pred.skip_binder().try_fold_with(&mut folder).ok()
}

fn finish<'tcx>(
self,
tcx: TyCtxt<'tcx>,
Expand Down
10 changes: 5 additions & 5 deletions crates/flux-middle/src/rty/subst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ pub trait GenericsSubstDelegate {
type Error = !;

fn sort_for_param(&mut self, param_ty: ParamTy) -> Result<Sort, Self::Error>;
fn ty_for_param(&mut self, param_ty: ParamTy) -> Ty;
fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, Self::Error>;
fn ctor_for_param(&mut self, param_ty: ParamTy) -> SubsetTyCtor;
fn region_for_param(&mut self, ebr: EarlyParamRegion) -> Region;
fn expr_for_param_const(&self, param_const: ParamConst) -> Expr;
Expand All @@ -358,9 +358,9 @@ impl<'a, 'tcx> GenericsSubstDelegate for GenericArgsDelegate<'a, 'tcx> {
}
}

fn ty_for_param(&mut self, param_ty: ParamTy) -> Ty {
fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, !> {
match self.0.get(param_ty.index as usize) {
Some(GenericArg::Ty(ty)) => ty.clone(),
Some(GenericArg::Ty(ty)) => Ok(ty.clone()),
Some(arg) => tracked_span_bug!("expected type for generic parameter, found `{arg:?}`"),
None => tracked_span_bug!("type parameter out of range {param_ty:?}"),
}
Expand Down Expand Up @@ -433,7 +433,7 @@ where
(self.sort_for_param)(param_ty)
}

fn ty_for_param(&mut self, param_ty: ParamTy) -> Ty {
fn ty_for_param(&mut self, param_ty: ParamTy) -> Result<Ty, E> {
bug!("unexpected type param {param_ty:?}");
}

Expand Down Expand Up @@ -497,7 +497,7 @@ impl<D: GenericsSubstDelegate> FallibleTypeFolder for GenericsSubstFolder<'_, D>

fn try_fold_ty(&mut self, ty: &Ty) -> Result<Ty, D::Error> {
match ty.kind() {
TyKind::Param(param_ty) => Ok(self.delegate.ty_for_param(*param_ty)),
TyKind::Param(param_ty) => self.delegate.ty_for_param(*param_ty),
TyKind::Indexed(BaseTy::Param(param_ty), idx) => {
let idx = idx.try_fold_with(self)?;
Ok(self
Expand Down
24 changes: 24 additions & 0 deletions tests/tests/pos/surface/issue-829.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
trait Trait1 {
type Assoc1;
}

impl Trait1 for i32 {
type Assoc1 = bool;
}

trait Trait2 {
type Assoc2;
}

struct S<T> {
fld: T,
}

impl<T1, T2> Trait2 for S<T2>
where
T2: Trait1<Assoc1 = T1>,
{
type Assoc2 = T1;
}

fn test(x: <S<i32> as Trait2>::Assoc2) {}
25 changes: 25 additions & 0 deletions tests/tests/pos/surface/issue-829b.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
trait Trait1 {
type Assoc1;
}

impl Trait1 for i32 {
type Assoc1 = i32;
}

trait Trait2 {
type Assoc2;
}

struct S<T> {
fld: T,
}

impl<T1, T2, T3> Trait2 for S<T1>
where
T2: Trait1<Assoc1 = T3>,
T1: Trait1<Assoc1 = T2>,
{
type Assoc2 = T1;
}

fn test(x: <S<i32> as Trait2>::Assoc2) {}
Loading