Skip to content

Commit fb0f12b

Browse files
authored
Merge pull request github#19756 from paldepind/rust/type-parameters-default
Rust: Type inference uses defaults for type parameters
2 parents 2a51749 + 8fe737c commit fb0f12b

File tree

5 files changed

+2717
-2640
lines changed

5 files changed

+2717
-2640
lines changed

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ abstract class Type extends TType {
4242
/** Gets the `i`th type parameter of this type, if any. */
4343
abstract TypeParameter getTypeParameter(int i);
4444

45+
/** Gets the default type for the `i`th type parameter, if any. */
46+
TypeMention getTypeParameterDefault(int i) { none() }
47+
4548
/** Gets a type parameter of this type. */
4649
final TypeParameter getATypeParameter() { result = this.getTypeParameter(_) }
4750

@@ -87,6 +90,10 @@ class StructType extends StructOrEnumType, TStruct {
8790
result = TTypeParamTypeParameter(struct.getGenericParamList().getTypeParam(i))
8891
}
8992

93+
override TypeMention getTypeParameterDefault(int i) {
94+
result = struct.getGenericParamList().getTypeParam(i).getDefaultType()
95+
}
96+
9097
override string toString() { result = struct.getName().getText() }
9198

9299
override Location getLocation() { result = struct.getLocation() }
@@ -108,6 +115,10 @@ class EnumType extends StructOrEnumType, TEnum {
108115
result = TTypeParamTypeParameter(enum.getGenericParamList().getTypeParam(i))
109116
}
110117

118+
override TypeMention getTypeParameterDefault(int i) {
119+
result = enum.getGenericParamList().getTypeParam(i).getDefaultType()
120+
}
121+
111122
override string toString() { result = enum.getName().getText() }
112123

113124
override Location getLocation() { result = enum.getLocation() }
@@ -133,6 +144,10 @@ class TraitType extends Type, TTrait {
133144
any(AssociatedTypeTypeParameter param | param.getTrait() = trait and param.getIndex() = i)
134145
}
135146

147+
override TypeMention getTypeParameterDefault(int i) {
148+
result = trait.getGenericParamList().getTypeParam(i).getDefaultType()
149+
}
150+
136151
override string toString() { result = trait.toString() }
137152

138153
override Location getLocation() { result = trait.getLocation() }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1594,6 +1594,11 @@ private module Debug {
15941594
result = resolveMethodCallTarget(mce)
15951595
}
15961596

1597+
predicate debugTypeMention(TypeMention tm, TypePath path, Type type) {
1598+
tm = getRelevantLocatable() and
1599+
tm.resolveTypeAt(path) = type
1600+
}
1601+
15971602
pragma[nomagic]
15981603
private int countTypes(AstNode n, TypePath path, Type t) {
15991604
t = inferType(n, path) and

rust/ql/lib/codeql/rust/internal/TypeMention.qll

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
8888
override TypeMention getTypeArgument(int i) {
8989
result = path.getSegment().getGenericArgList().getTypeArg(i)
9090
or
91+
// If a type argument is not given in the path, then we use the default for
92+
// the type parameter if one exists for the type.
93+
not exists(path.getSegment().getGenericArgList().getTypeArg(i)) and
94+
result = this.resolveType().getTypeParameterDefault(i)
95+
or
9196
// `Self` paths inside `impl` blocks have implicit type arguments that are
9297
// the type parameters of the `impl` block. For example, in
9398
//

rust/ql/test/library-tests/type-inference/main.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ mod field_access {
1414
}
1515

1616
#[derive(Debug)]
17-
struct GenericThing<A> {
17+
struct GenericThing<A = bool> {
1818
a: A,
1919
}
2020

@@ -27,6 +27,11 @@ mod field_access {
2727
println!("{:?}", x.a); // $ fieldof=MyThing
2828
}
2929

30+
fn default_field_access(x: GenericThing) {
31+
let a = x.a; // $ fieldof=GenericThing type=a:bool
32+
println!("{:?}", a);
33+
}
34+
3035
fn generic_field_access() {
3136
// Explicit type argument
3237
let x = GenericThing::<S> { a: S }; // $ type=x:A.S
@@ -472,16 +477,16 @@ mod type_parameter_bounds {
472477
println!("{:?}", s); // $ type=s:S1
473478
}
474479

475-
trait Pair<P1, P2> {
480+
trait Pair<P1 = bool, P2 = i64> {
476481
fn fst(self) -> P1;
477482

478483
fn snd(self) -> P2;
479484
}
480485

481486
fn call_trait_per_bound_with_type_1<T: Pair<S1, S2>>(x: T, y: T) {
482487
// The type in the type parameter bound determines the return type.
483-
let s1 = x.fst(); // $ method=fst
484-
let s2 = y.snd(); // $ method=snd
488+
let s1 = x.fst(); // $ method=fst type=s1:S1
489+
let s2 = y.snd(); // $ method=snd type=s2:S2
485490
println!("{:?}, {:?}", s1, s2);
486491
}
487492

@@ -491,6 +496,20 @@ mod type_parameter_bounds {
491496
let s2 = y.snd(); // $ method=snd
492497
println!("{:?}, {:?}", s1, s2);
493498
}
499+
500+
fn call_trait_per_bound_with_type_3<T: Pair>(x: T, y: T) {
501+
// The type in the type parameter bound determines the return type.
502+
let s1 = x.fst(); // $ method=fst type=s1:bool
503+
let s2 = y.snd(); // $ method=snd type=s2:i64
504+
println!("{:?}, {:?}", s1, s2);
505+
}
506+
507+
fn call_trait_per_bound_with_type_4<T: Pair<u8>>(x: T, y: T) {
508+
// The type in the type parameter bound determines the return type.
509+
let s1 = x.fst(); // $ method=fst type=s1:u8
510+
let s2 = y.snd(); // $ method=snd type=s2:i64
511+
println!("{:?}, {:?}", s1, s2);
512+
}
494513
}
495514

496515
mod function_trait_bounds {

0 commit comments

Comments
 (0)