Skip to content

Commit 6590777

Browse files
authored
Merge pull request #19062 from paldepind/rust-ti-1
Rust: Improve handling of trait bounds
2 parents e2d6643 + b02a249 commit 6590777

File tree

5 files changed

+950
-705
lines changed

5 files changed

+950
-705
lines changed

rust/ql/lib/codeql/rust/internal/Type.qll

+11-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ abstract private class StructOrEnumType extends Type {
8181
)
8282
}
8383

84+
/** Gets all of the fully parametric `impl` blocks that target this type. */
8485
final override ImplMention getABaseTypeMention() {
8586
this.asItemNode() = result.resolveSelfTy() and
8687
result.isFullyParametric()
@@ -153,6 +154,7 @@ class TraitType extends Type, TTrait {
153154
result = trait.getTypeBoundList().getABound().getTypeRepr()
154155
}
155156

157+
/** Gets any of the trait bounds of this trait. */
156158
override TypeMention getABaseTypeMention() { result = this.getABoundMention() }
157159

158160
override string toString() { result = trait.toString() }
@@ -308,11 +310,19 @@ class TypeParamTypeParameter extends TypeParameter, TTypeParamTypeParameter {
308310

309311
TypeParam getTypeParam() { result = typeParam }
310312

311-
override Function getMethod(string name) { result = typeParam.(ItemNode).getASuccessor(name) }
313+
override Function getMethod(string name) {
314+
// NOTE: If the type parameter has trait bounds, then this finds methods
315+
// on the bounding traits.
316+
result = typeParam.(ItemNode).getASuccessor(name)
317+
}
312318

313319
override string toString() { result = typeParam.toString() }
314320

315321
override Location getLocation() { result = typeParam.getLocation() }
322+
323+
final override TypeMention getABaseTypeMention() {
324+
result = typeParam.getTypeBoundList().getABound().getTypeRepr()
325+
}
316326
}
317327

318328
/** An implicit reference type parameter. */

rust/ql/lib/codeql/rust/internal/TypeInference.qll

+4
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,10 @@ private Type inferImplicitSelfType(SelfParam self, TypePath path) {
234234
)
235235
}
236236

237+
/**
238+
* Gets any of the types mentioned in `path` that corresponds to the type
239+
* parameter `tp`.
240+
*/
237241
private TypeMention getExplicitTypeArgMention(Path path, TypeParam tp) {
238242
exists(int i |
239243
result = path.getPart().getGenericArgList().getTypeArg(pragma[only_bind_into](i)) and

rust/ql/test/library-tests/type-inference/main.rs

+107-2
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ mod field_access {
3636
let y = GenericThing { a: S };
3737
println!("{:?}", x.a);
3838

39-
// The type of the field `a` can only be infered from the concrete type
39+
// The type of the field `a` can only be inferred from the concrete type
4040
// in the struct declaration.
4141
let x = OptionS {
4242
a: MyOption::MyNone(),
4343
};
4444
println!("{:?}", x.a);
4545

46-
// The type of the field `a` can only be infered from the type argument
46+
// The type of the field `a` can only be inferred from the type argument
4747
let x = GenericThing::<MyOption<S>> {
4848
a: MyOption::MyNone(),
4949
};
@@ -191,6 +191,68 @@ mod method_non_parametric_trait_impl {
191191
}
192192
}
193193

194+
mod type_parameter_bounds {
195+
use std::fmt::Debug;
196+
197+
#[derive(Debug)]
198+
struct S1;
199+
200+
#[derive(Debug)]
201+
struct S2;
202+
203+
// Two traits with the same method name.
204+
205+
trait FirstTrait<FT> {
206+
fn method(self) -> FT;
207+
}
208+
209+
trait SecondTrait<ST> {
210+
fn method(self) -> ST;
211+
}
212+
213+
fn call_first_trait_per_bound<I: Debug, T: SecondTrait<I>>(x: T) {
214+
// The type parameter bound determines which method this call is resolved to.
215+
let s1 = x.method();
216+
println!("{:?}", s1);
217+
}
218+
219+
fn call_second_trait_per_bound<I: Debug, T: SecondTrait<I>>(x: T) {
220+
// The type parameter bound determines which method this call is resolved to.
221+
let s2 = x.method();
222+
println!("{:?}", s2);
223+
}
224+
225+
fn trait_bound_with_type<T: FirstTrait<S1>>(x: T) {
226+
let s = x.method();
227+
println!("{:?}", s);
228+
}
229+
230+
fn trait_per_bound_with_type<T: FirstTrait<S1>>(x: T) {
231+
let s = x.method();
232+
println!("{:?}", s);
233+
}
234+
235+
trait Pair<P1, P2> {
236+
fn fst(self) -> P1;
237+
238+
fn snd(self) -> P2;
239+
}
240+
241+
fn call_trait_per_bound_with_type_1<T: Pair<S1, S2>>(x: T, y: T) {
242+
// The type in the type parameter bound determines the return type.
243+
let s1 = x.fst();
244+
let s2 = y.snd();
245+
println!("{:?}, {:?}", s1, s2);
246+
}
247+
248+
fn call_trait_per_bound_with_type_2<T2: Debug, T: Pair<S1, T2>>(x: T, y: T) {
249+
// The type in the type parameter bound determines the return type.
250+
let s1 = x.fst();
251+
let s2 = y.snd();
252+
println!("{:?}, {:?}", s1, s2);
253+
}
254+
}
255+
194256
mod function_trait_bounds {
195257
#[derive(Debug)]
196258
struct MyThing<A> {
@@ -443,6 +505,49 @@ mod function_trait_bounds_2 {
443505
}
444506
}
445507

508+
mod type_aliases {
509+
#[derive(Debug)]
510+
enum PairOption<Fst, Snd> {
511+
PairNone(),
512+
PairFst(Fst),
513+
PairSnd(Snd),
514+
PairBoth(Fst, Snd),
515+
}
516+
517+
#[derive(Debug)]
518+
struct S1;
519+
520+
#[derive(Debug)]
521+
struct S2;
522+
523+
#[derive(Debug)]
524+
struct S3;
525+
526+
// Non-generic type alias that fully applies the generic type
527+
type MyPair = PairOption<S1, S2>;
528+
529+
// Generic type alias that partially applies the generic type
530+
type AnotherPair<Thr> = PairOption<S2, Thr>;
531+
532+
pub fn f() {
533+
// Type can be inferred from the constructor
534+
let p1: MyPair = PairOption::PairBoth(S1, S2);
535+
println!("{:?}", p1);
536+
537+
// Type can be only inferred from the type alias
538+
let p2: MyPair = PairOption::PairNone(); // types for `Fst` and `Snd` missing
539+
println!("{:?}", p2);
540+
541+
// First type from alias, second from constructor
542+
let p3: AnotherPair<_> = PairOption::PairSnd(S3); // type for `Fst` missing
543+
println!("{:?}", p3);
544+
545+
// First type from alias definition, second from argument to alias
546+
let p3: AnotherPair<S3> = PairOption::PairNone(); // type for `Snd` missing, spurious `S3` for `Fst`
547+
println!("{:?}", p3);
548+
}
549+
}
550+
446551
mod option_methods {
447552
#[derive(Debug)]
448553
enum MyOption<T> {

0 commit comments

Comments
 (0)