Skip to content

Commit b39acb1

Browse files
committed
Rust: Restrict type propagation into arguments
1 parent dcb2bf2 commit b39acb1

File tree

12 files changed

+114
-561
lines changed

12 files changed

+114
-561
lines changed

rust/ql/lib/codeql/rust/internal/Type.qll

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ newtype TType =
5151
TSliceType() or
5252
TNeverType() or
5353
TPtrType() or
54+
TContextType() or
5455
TTupleTypeParameter(int arity, int i) { exists(TTuple(arity)) and i in [0 .. arity - 1] } or
5556
TTypeParamTypeParameter(TypeParam t) or
5657
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
@@ -371,6 +372,14 @@ class PtrType extends Type, TPtrType {
371372
override Location getLocation() { result instanceof EmptyLocation }
372373
}
373374

375+
class ContextType extends Type, TContextType {
376+
override TypeParameter getPositionalTypeParameter(int i) { none() }
377+
378+
override string toString() { result = "(context typed)" }
379+
380+
override Location getLocation() { result instanceof EmptyLocation }
381+
}
382+
374383
/** A type parameter. */
375384
abstract class TypeParameter extends Type {
376385
override TypeParameter getPositionalTypeParameter(int i) { none() }

rust/ql/lib/codeql/rust/internal/TypeInference.qll

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,7 +1961,30 @@ private Type inferMethodCallType0(
19611961
) {
19621962
exists(TypePath path0 |
19631963
n = a.getNodeAt(apos) and
1964-
result = MethodCallMatching::inferAccessType(a, derefChainBorrow, apos, path0)
1964+
(
1965+
result = MethodCallMatching::inferAccessType(a, derefChainBorrow, apos, path0)
1966+
or
1967+
apos.isReturn() and
1968+
exists(Method target, TypeParameter tp |
1969+
target = a.getTarget(derefChainBorrow) and
1970+
assocFunctionContextTypedAt(target, _, path0, tp) and
1971+
not exists(TypeArgumentPosition tapos |
1972+
exists(int i |
1973+
i = tapos.asMethodTypeArgumentPosition() and
1974+
tp = TTypeParamTypeParameter(target.getGenericParamList().getTypeParam(i))
1975+
)
1976+
or
1977+
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
1978+
|
1979+
exists(a.getTypeArgument(tapos, _))
1980+
) and
1981+
not (
1982+
tp instanceof TSelfTypeParameter and
1983+
exists(getCallExprTypeQualifier(a, _))
1984+
)
1985+
) and
1986+
result = TContextType()
1987+
)
19651988
|
19661989
if
19671990
// index expression `x[i]` desugars to `*x.index(i)`, so we must account for
@@ -1973,6 +1996,9 @@ private Type inferMethodCallType0(
19731996
)
19741997
}
19751998

1999+
pragma[nomagic]
2000+
private TypePath getContextTypePath(AstNode n) { inferType(n, result) = TContextType() }
2001+
19762002
/**
19772003
* Gets the type of `n` at `path`, where `n` is either a method call or an
19782004
* argument/receiver of a method call.
@@ -1983,7 +2009,8 @@ private Type inferMethodCallType(AstNode n, TypePath path) {
19832009
MethodCallMatchingInput::Access a, MethodCallMatchingInput::AccessPosition apos,
19842010
string derefChainBorrow, TypePath path0
19852011
|
1986-
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0)
2012+
result = inferMethodCallType0(a, apos, n, derefChainBorrow, path0) and
2013+
if not apos.isReturn() then path.startsWith(getContextTypePath(n)) else any()
19872014
|
19882015
(
19892016
not apos.isSelf()
@@ -2171,6 +2198,12 @@ private module NonMethodResolution {
21712198
or
21722199
result = this.resolveCallTargetRec()
21732200
}
2201+
2202+
pragma[nomagic]
2203+
Function resolveTraitFunction() {
2204+
this.(Call).hasTrait() and
2205+
result = this.getPathResolutionResolved()
2206+
}
21742207
}
21752208

21762209
private newtype TCallAndBlanketPos =
@@ -2431,7 +2464,30 @@ pragma[nomagic]
24312464
private Type inferNonMethodCallType(AstNode n, TypePath path) {
24322465
exists(NonMethodCallMatchingInput::Access a, NonMethodCallMatchingInput::AccessPosition apos |
24332466
n = a.getNodeAt(apos) and
2467+
if not apos.isReturn() then path.startsWith(getContextTypePath(n)) else any()
2468+
|
24342469
result = NonMethodCallMatching::inferAccessType(a, apos, path)
2470+
or
2471+
apos.isReturn() and
2472+
exists(Function target, TypeParameter tp |
2473+
target = [a.getTarget().(Function), a.resolveTraitFunction()] and
2474+
assocFunctionContextTypedAt(target, _, path, tp) and
2475+
not exists(TypeArgumentPosition tapos |
2476+
// exists(int i |
2477+
// i = tapos.asMethodTypeArgumentPosition() and
2478+
// tp = TTypeParamTypeParameter(target.getGenericParamList().getTypeParam(i))
2479+
// )
2480+
// or
2481+
TTypeParamTypeParameter(tapos.asTypeParam()) = tp
2482+
|
2483+
exists(a.getTypeArgument(tapos, _))
2484+
) and
2485+
not (
2486+
tp instanceof TSelfTypeParameter and
2487+
exists(getCallExprTypeQualifier(a, _))
2488+
)
2489+
) and
2490+
result = TContextType()
24352491
)
24362492
}
24372493

@@ -2510,7 +2566,8 @@ pragma[nomagic]
25102566
private Type inferOperationType(AstNode n, TypePath path) {
25112567
exists(OperationMatchingInput::Access a, OperationMatchingInput::AccessPosition apos |
25122568
n = a.getNodeAt(apos) and
2513-
result = OperationMatching::inferAccessType(a, apos, path)
2569+
result = OperationMatching::inferAccessType(a, apos, path) and
2570+
if not apos.isReturn() then path.startsWith(getContextTypePath(n)) else any()
25142571
)
25152572
}
25162573

@@ -3291,8 +3348,10 @@ private module Debug {
32913348
Locatable getRelevantLocatable() {
32923349
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
32933350
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
3294-
filepath.matches("%/sqlx.rs") and
3295-
startline = [56 .. 60]
3351+
filepath.matches("%/crate/data_derive/src/lib.rs") and
3352+
startline = [48, 74]
3353+
// filepath.matches("%/main.rs") and
3354+
// startline = [2525]
32963355
)
32973356
}
32983357

@@ -3355,7 +3414,7 @@ private module Debug {
33553414
}
33563415

33573416
predicate countTypesForNodeAtLimit(AstNode n, int c) {
3358-
n = getRelevantLocatable() and
3417+
// n = getRelevantLocatable() and
33593418
c = strictcount(Type t, TypePath path | t = debugInferTypeForNodeAtLimit(n, path))
33603419
}
33613420

rust/ql/lib/codeql/rust/internal/typeinference/FunctionType.qll

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,18 @@ module ArgsAreInstantiationsOf<ArgsAreInstantiationsOfInputSig Input> {
418418
)
419419
}
420420
}
421+
422+
pragma[nomagic]
423+
predicate assocFunctionContextTypedAt(
424+
Function f, ImplOrTraitItemNode i, TypePath path, TypeParameter tp
425+
) {
426+
// f.getName().getText() = "default" and
427+
exists(FunctionPosition resPos |
428+
resPos.isReturn() and
429+
assocFunctionTypeAt(f, i, resPos, path, tp)
430+
) and
431+
not exists(FunctionPosition nonResPos |
432+
not nonResPos.isReturn() and
433+
assocFunctionTypeAt(f, i, nonResPos, _, tp)
434+
)
435+
}

rust/ql/test/library-tests/dataflow/sources/CONSISTENCY/PathResolutionConsistency.expected

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ multipleCallTargets
22
| test.rs:113:62:113:77 | ...::from(...) |
33
| test.rs:120:58:120:73 | ...::from(...) |
44
| test.rs:229:22:229:72 | ... .read_to_string(...) |
5+
| test.rs:911:24:911:34 | row.take(...) |
6+
| test.rs:998:24:998:34 | row.take(...) |
57
| test.rs:1096:50:1096:66 | ...::from(...) |
68
| test.rs:1096:50:1096:66 | ...::from(...) |
79
| test_futures_io.rs:35:26:35:63 | pinned.poll_read(...) |
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
multipleCallTargets
2+
| test.rs:288:7:288:36 | ... .as_str() |

0 commit comments

Comments
 (0)