Skip to content

Commit b82cc7b

Browse files
samebGuice Team
authored andcommitted
Eagerly ensure that user-supplied providers always return the expected type. Historically this was implicitly validated upon injection into whatever used the provider (via reflection magic). However, with the shift to MethodHandles, that particular signature check becomes harder to isolate/replicate without adding a bunch of unnecessary overhead. Thus, we instead change this to validate user-supplied providers when they provide the value (which is the soonest it can go wrong).
PiperOrigin-RevId: 752295067
1 parent 73ed742 commit b82cc7b

File tree

10 files changed

+104
-33
lines changed

10 files changed

+104
-33
lines changed

core/src/com/google/inject/internal/BindingProcessor.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,13 @@ public Boolean visit(ProviderInstanceBinding<? extends T> binding) {
154154
ProvisionListenerStackCallback<T> listener =
155155
injector.provisionListenerStore.get((Binding<T>) binding);
156156
int circularFactoryId = injector.circularFactoryIdFactory.next();
157+
Class<? super T> rawType = key.getTypeLiteral().getRawType();
157158
InternalFactory<T> factory =
158159
(initializable.isPresent()
159160
? new InternalFactoryToInitializableAdapter<T>(
160-
initializable.get(), source, listener, circularFactoryId)
161+
rawType, initializable.get(), source, listener, circularFactoryId)
161162
: new ConstantProviderInternalFactory<T>(
162-
provider, source, listener, circularFactoryId));
163+
rawType, provider, source, listener, circularFactoryId));
163164

164165
InternalFactory<? extends T> scopedFactory =
165166
Scoping.scope(key, injector, factory, source, scoping);
@@ -178,6 +179,7 @@ public Boolean visit(ProviderKeyBinding<? extends T> binding) {
178179
@SuppressWarnings("unchecked")
179180
BoundProviderFactory<T> boundProviderFactory =
180181
new BoundProviderFactory<T>(
182+
key.getTypeLiteral().getRawType(),
181183
injector,
182184
providerKey,
183185
source,

core/src/com/google/inject/internal/BoundProviderFactory.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@ final class BoundProviderFactory<T> extends ProviderInternalFactory<T> implement
2929
private InternalFactory<? extends jakarta.inject.Provider<? extends T>> providerFactory;
3030

3131
BoundProviderFactory(
32+
Class<? super T> rawType,
3233
InjectorImpl injector,
3334
Key<? extends jakarta.inject.Provider<? extends T>> providerKey,
3435
Object source,
3536
ProvisionListenerStackCallback<T> provisionCallback) {
36-
super(source, injector.circularFactoryIdFactory.next());
37+
super(rawType, source, injector.circularFactoryIdFactory.next());
3738
this.provisionCallback = provisionCallback;
3839
this.injector = injector;
3940
this.providerKey = providerKey;

core/src/com/google/inject/internal/ConstantProviderInternalFactory.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,21 @@
1919
import static com.google.common.base.Preconditions.checkNotNull;
2020

2121
import com.google.inject.spi.Dependency;
22-
import jakarta.inject.Provider;
2322
import javax.annotation.Nullable;
23+
import jakarta.inject.Provider;
2424

2525
/** An InternalFactory that delegates to a constant provider. */
2626
final class ConstantProviderInternalFactory<T> extends ProviderInternalFactory<T> {
2727
private final Provider<T> provider;
2828
@Nullable private final ProvisionListenerStackCallback<T> provisionCallback;
2929

3030
ConstantProviderInternalFactory(
31+
Class<? super T> rawType,
3132
Provider<T> provider,
3233
Object source,
3334
@Nullable ProvisionListenerStackCallback<T> provisionCallback,
3435
int circularFactoryId) {
35-
super(source, circularFactoryId);
36+
super(rawType, source, circularFactoryId);
3637
this.provider = checkNotNull(provider);
3738
this.provisionCallback = provisionCallback;
3839
}

core/src/com/google/inject/internal/ConstructorInjector.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ MethodHandle getConstructHandle(
121121
MethodHandles.dropArguments(
122122
MethodHandles.identity(Object.class), 1, InternalContext.class),
123123
membersHandle);
124-
// Then execute thje membersHandle after constructing the object (and calling
124+
// Then execute the membersHandle after constructing the object (and calling
125125
// finishConstructionAndSetReference)
126126
handle = MethodHandles.foldArguments(membersHandle, handle);
127127
} else {

core/src/com/google/inject/internal/InternalFactoryToInitializableAdapter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,12 @@ final class InternalFactoryToInitializableAdapter<T> extends ProviderInternalFac
3333
private final Initializable<? extends jakarta.inject.Provider<? extends T>> initializable;
3434

3535
public InternalFactoryToInitializableAdapter(
36+
Class<? super T> rawType,
3637
Initializable<? extends jakarta.inject.Provider<? extends T>> initializable,
3738
Object source,
3839
ProvisionListenerStackCallback<T> provisionCallback,
3940
int circularFactoryId) {
40-
super(source, circularFactoryId);
41+
super(rawType, source, circularFactoryId);
4142
this.provisionCallback = provisionCallback;
4243
this.initializable = checkNotNull(initializable, "provider");
4344
}

core/src/com/google/inject/internal/MembersInjectorImpl.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,6 @@ MethodHandle getInjectMembersAndNotifyListenersHandle(@Nullable LinkageContext l
307307
// (Object, InternalContext)->void
308308
var notifyListeners =
309309
MethodHandles.dropArguments(getNotifyListenersHandle(), 1, InternalContext.class);
310-
;
311310

312311
local = MethodHandles.foldArguments(notifyListeners, injectMembers);
313312

core/src/com/google/inject/internal/ProvidedByInternalFactory.java

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
import com.google.inject.Key;
2323
import com.google.inject.internal.InjectorImpl.JitLimitation;
2424
import com.google.inject.spi.Dependency;
25-
import jakarta.inject.Provider;
2625
import java.lang.invoke.MethodHandle;
2726
import java.lang.invoke.MethodHandles;
27+
import jakarta.inject.Provider;
2828

2929
/**
3030
* An {@link InternalFactory} for {@literal @}{@link ProvidedBy} bindings.
@@ -33,7 +33,6 @@
3333
*/
3434
class ProvidedByInternalFactory<T> extends ProviderInternalFactory<T> implements DelayedInitialize {
3535

36-
private final Class<?> rawType;
3736
private final Class<? extends Provider<?>> providerType;
3837
private final Key<? extends Provider<T>> providerKey;
3938
private InternalFactory<? extends Provider<T>> providerFactory;
@@ -44,8 +43,7 @@ class ProvidedByInternalFactory<T> extends ProviderInternalFactory<T> implements
4443
Class<? extends Provider<?>> providerType,
4544
Key<? extends Provider<T>> providerKey,
4645
int circularFactoryId) {
47-
super(providerKey, circularFactoryId);
48-
this.rawType = rawType;
46+
super(rawType, providerKey, circularFactoryId);
4947
this.providerType = providerType;
5048
this.providerKey = providerKey;
5149
}
@@ -86,13 +84,11 @@ MethodHandleResult makeHandle(LinkageContext context, boolean linked) {
8684
}
8785

8886
@Override
89-
protected MethodHandle provisionHandle(MethodHandle providerHandle) {
90-
// Do normal provisioning and then check that the result is the correct subtype.
91-
MethodHandle invokeProvider = super.provisionHandle(providerHandle);
87+
protected MethodHandle validateReturnTypeHandle(MethodHandle providerHandle) {
9288
return MethodHandles.filterReturnValue(
93-
invokeProvider,
89+
providerHandle,
9490
MethodHandles.insertArguments(
95-
CHECK_SUBTYPE_NOT_PROVIDED_MH, 1, source, providerType, rawType));
91+
CHECK_SUBTYPE_NOT_PROVIDED_MH, 1, source, providerType, providedRawType));
9692
}
9793

9894
private static final MethodHandle CHECK_SUBTYPE_NOT_PROVIDED_MH =
@@ -101,6 +97,8 @@ protected MethodHandle provisionHandle(MethodHandle providerHandle) {
10197
"doCheckSubtypeNotProvided",
10298
methodType(Object.class, Object.class, Object.class, Class.class, Class.class));
10399

100+
// Historically this had a different error check than other providers,
101+
// so we preserve that behavior.
104102
@Keep
105103
static Object doCheckSubtypeNotProvided(
106104
Object result,
@@ -115,18 +113,13 @@ static Object doCheckSubtypeNotProvided(
115113
return result;
116114
}
117115

116+
// Historically this had a different error check than other providers,
117+
// so we preserve that behavior.
118118
@Override
119-
protected T provision(
120-
jakarta.inject.Provider<? extends T> provider,
121-
InternalContext context,
122-
Dependency<?> dependency)
123-
throws InternalProvisionException {
124-
Object o = super.provision(provider, context, dependency);
125-
if (o != null && !rawType.isInstance(o)) {
126-
throw InternalProvisionException.subtypeNotProvided(providerType, rawType).addSource(source);
119+
protected void validateReturnType(T t) throws InternalProvisionException {
120+
if (t != null && !providedRawType.isInstance(t)) {
121+
throw InternalProvisionException.subtypeNotProvided(providerType, providedRawType)
122+
.addSource(source);
127123
}
128-
@SuppressWarnings("unchecked") // protected by isInstance() check above
129-
T t = (T) o;
130-
return t;
131124
}
132125
}

core/src/com/google/inject/internal/ProviderInternalFactory.java

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,25 @@
2020
import static com.google.inject.internal.InternalMethodHandles.castReturnTo;
2121
import static java.lang.invoke.MethodType.methodType;
2222

23+
import com.google.errorprone.annotations.Keep;
2324
import com.google.inject.spi.Dependency;
24-
import jakarta.inject.Provider;
2525
import java.lang.invoke.MethodHandle;
2626
import java.lang.invoke.MethodHandles;
2727
import javax.annotation.Nullable;
28+
import jakarta.inject.Provider;
2829

2930
/**
3031
* Base class for InternalFactories that are used by Providers, to handle circular dependencies.
3132
*
3233
* @author [email protected] (Sam Berlin)
3334
*/
3435
abstract class ProviderInternalFactory<T> extends InternalFactory<T> {
36+
protected final Class<?> providedRawType; // Technically this is "? super T", but this is easier.
3537
protected final Object source;
3638
private final int circularFactoryId;
3739

38-
ProviderInternalFactory(Object source, int circularFactoryId) {
40+
ProviderInternalFactory(Class<?> providedRawType, Object source, int circularFactoryId) {
41+
this.providedRawType = providedRawType;
3942
this.source = checkNotNull(source, "source");
4043
this.circularFactoryId = circularFactoryId;
4144
}
@@ -150,9 +153,36 @@ protected MethodHandle provisionHandle(MethodHandle providerHandle) {
150153
}
151154
// null check the result using the dependency.
152155
invokeProvider = InternalMethodHandles.nullCheckResult(invokeProvider, source);
156+
invokeProvider = validateReturnTypeHandle(invokeProvider);
153157
return invokeProvider;
154158
}
155159

160+
protected MethodHandle validateReturnTypeHandle(MethodHandle resultHandle) {
161+
return MethodHandles.filterReturnValue(
162+
resultHandle,
163+
MethodHandles.insertArguments(CHECK_SUBTYPE_NOT_PROVIDED_MH, 1, source, providedRawType));
164+
}
165+
166+
private static final MethodHandle CHECK_SUBTYPE_NOT_PROVIDED_MH =
167+
InternalMethodHandles.findStaticOrDie(
168+
ProviderInternalFactory.class,
169+
"doCheckSubtypeNotProvided",
170+
methodType(Object.class, Object.class, Object.class, Class.class));
171+
172+
@Keep
173+
static Object doCheckSubtypeNotProvided(Object result, Object source, Class<?> providedType)
174+
throws InternalProvisionException {
175+
if (result != null && !providedType.isInstance(result)) {
176+
// Historically this was surfaced as a ProvisionException embedding a
177+
// ClassCastException, so we keep that behavior. (Maybe one day we can
178+
// shift it to an explicit error without the inner ClassCastException.)
179+
throw InternalProvisionException.errorInProvider(
180+
new ClassCastException("Cannot cast " + result.getClass() + " to " + providedType))
181+
.addSource(source);
182+
}
183+
return result;
184+
}
185+
156186
/**
157187
* Provisions a new instance. Subclasses should override this to catch exceptions and rethrow as
158188
* ErrorsExceptions.
@@ -171,6 +201,18 @@ protected T provision(
171201
if (t == null && !dependency.isNullable()) {
172202
InternalProvisionException.onNullInjectedIntoNonNullableDependency(source, dependency);
173203
}
204+
validateReturnType(t);
174205
return t;
175206
}
207+
208+
protected void validateReturnType(T t) throws InternalProvisionException {
209+
if (t != null && !providedRawType.isInstance(t)) {
210+
// Historically this was surfaced as a ProvisionException embedding a
211+
// ClassCastException, so we keep that behavior. (Maybe one day we can
212+
// shift it to an explicit error without the inner ClassCastException.
213+
throw InternalProvisionException.errorInProvider(
214+
new ClassCastException("Cannot cast " + t.getClass() + " to " + providedRawType))
215+
.addSource(source);
216+
}
217+
}
176218
}

core/test/com/google/inject/ProvisionExceptionsTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,33 @@ Object providesDependsOnFailingProvider(String failing) {
326326
provideFailingValue); // the thing that failed.
327327
}
328328

329+
@Test
330+
@SuppressWarnings({"rawtypes", "unchecked"}) // Test requires incorrect types
331+
public void testConstructorBindingDependencyHasWrongType() {
332+
var module =
333+
new AbstractModule() {
334+
@Override
335+
protected void configure() {
336+
bind(String.class)
337+
.toProvider(
338+
new Provider() {
339+
@Override
340+
public Integer get() {
341+
return 1;
342+
}
343+
});
344+
}
345+
};
346+
Injector injector = Guice.createInjector(module);
347+
var pe = assertThrows(ProvisionException.class, () -> injector.getInstance(WantsString.class));
348+
assertThat(pe).hasCauseThat().isInstanceOf(ClassCastException.class);
349+
}
350+
351+
private static class WantsString {
352+
@Inject
353+
WantsString(String s) {}
354+
}
355+
329356
private static interface Exploder {}
330357

331358
public static class Explosion implements Exploder {

extensions/testlib/test/com/google/inject/testing/fieldbinder/BoundFieldModuleTest.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ public Object get() {
669669
}
670670

671671
@SuppressWarnings({"rawtypes", "JUnitIncompatibleType"}) // Testing rawtypes
672-
public void testRawProviderCanBindToIncorrectType() {
672+
public void testRawProviderCannotBindToIncorrectType() {
673673
final Integer testValue = 1024;
674674
Object instance =
675675
new Object() {
@@ -686,7 +686,12 @@ public Object get() {
686686
BoundFieldModule module = BoundFieldModule.of(instance);
687687
Injector injector = Guice.createInjector(module);
688688

689-
assertEquals(testValue, injector.getInstance(String.class));
689+
try {
690+
injector.getInstance(String.class);
691+
fail();
692+
} catch (ProvisionException e) {
693+
assertEquals(e.getCause().getClass(), ClassCastException.class);
694+
}
690695
}
691696

692697
public void testMultipleBindErrorsAreAggregated() {
@@ -1006,7 +1011,7 @@ public void testGetBoundFields_getField() throws Exception {
10061011

10071012
assertEquals(value, injector.getInstance(info.getBoundKey()));
10081013
}
1009-
1014+
10101015
public void testGetBoundFields_getKey() throws Exception {
10111016
Object instance =
10121017
new Object() {

0 commit comments

Comments
 (0)