@@ -24,34 +24,29 @@ import au.com.dius.pact.core.model.annotations.PactFolder
24
24
import au.com.dius.pact.core.model.messaging.MessagePact
25
25
import au.com.dius.pact.core.support.Annotations
26
26
import au.com.dius.pact.core.support.BuiltToolConfig
27
- import au.com.dius.pact.core.support.Json
28
27
import au.com.dius.pact.core.support.MetricEvent
29
28
import au.com.dius.pact.core.support.Metrics
30
29
import au.com.dius.pact.core.support.expressions.DataType
31
30
import au.com.dius.pact.core.support.expressions.ExpressionParser
32
31
import au.com.dius.pact.core.support.isNotEmpty
33
32
import io.github.oshai.kotlinlogging.KLogging
33
+ import org.apache.hc.core5.util.ReflectionUtils
34
34
import org.junit.jupiter.api.Disabled
35
35
import org.junit.jupiter.api.Nested
36
- import org.junit.jupiter.api.extension.AfterAllCallback
37
- import org.junit.jupiter.api.extension.AfterTestExecutionCallback
38
- import org.junit.jupiter.api.extension.BeforeAllCallback
39
- import org.junit.jupiter.api.extension.BeforeTestExecutionCallback
40
- import org.junit.jupiter.api.extension.Extension
41
- import org.junit.jupiter.api.extension.ExtensionContext
42
- import org.junit.jupiter.api.extension.ParameterContext
43
- import org.junit.jupiter.api.extension.ParameterResolver
36
+ import org.junit.jupiter.api.TestTemplate
37
+ import org.junit.jupiter.api.extension.*
44
38
import org.junit.platform.commons.support.AnnotationSupport
45
39
import org.junit.platform.commons.support.HierarchyTraversalMode
46
40
import org.junit.platform.commons.support.ReflectionSupport
47
41
import org.junit.platform.commons.util.AnnotationUtils.isAnnotated
48
42
import java.lang.reflect.Method
49
43
import java.util.Optional
50
44
import java.util.concurrent.ConcurrentHashMap
45
+ import java.util.stream.Stream
51
46
import kotlin.reflect.full.findAnnotation
52
47
53
48
class PactConsumerTestExt : Extension , BeforeTestExecutionCallback , BeforeAllCallback , ParameterResolver ,
54
- AfterTestExecutionCallback , AfterAllCallback {
49
+ AfterTestExecutionCallback , AfterAllCallback , TestTemplateInvocationContextProvider {
55
50
56
51
private val ep: ExpressionParser = ExpressionParser ()
57
52
@@ -103,6 +98,21 @@ class PactConsumerTestExt : Extension, BeforeTestExecutionCallback, BeforeAllCal
103
98
return false
104
99
}
105
100
101
+ override fun supportsTestTemplate (extensionContext : ExtensionContext ): Boolean {
102
+ val testTemplate = extensionContext
103
+ .testClass.get()
104
+ .methods
105
+ .find { AnnotationSupport .isAnnotated(it, TestTemplate ::class .java) }
106
+
107
+ return testTemplate != null && testTemplate.parameters[0 ].type == AsynchronousMessageContext ::class .java
108
+ }
109
+
110
+ override fun provideTestTemplateInvocationContexts (extensionContext : ExtensionContext ): Stream <TestTemplateInvocationContext > {
111
+ val providerInfo = this .lookupProviderInfo(extensionContext)
112
+ val pact = setupPactForTest(providerInfo[0 ].first, providerInfo[0 ].second, extensionContext)
113
+ return pact.asV4Pact().unwrap().interactions.map { AsynchronousMessageContext (it.asAsynchronousMessage()!! ) }.stream() as Stream <TestTemplateInvocationContext >
114
+ }
115
+
106
116
override fun resolveParameter (parameterContext : ParameterContext , extensionContext : ExtensionContext ): Any {
107
117
val type = parameterContext.parameter.type
108
118
val providers = lookupProviderInfo(extensionContext)
@@ -259,36 +269,38 @@ class PactConsumerTestExt : Extension, BeforeTestExecutionCallback, BeforeAllCal
259
269
): BasePact {
260
270
val store = context.getStore(NAMESPACE )
261
271
val key = " pact:${providerInfo.providerName} "
272
+ var methods = pactMethods
273
+ if (methods.isEmpty()) {
274
+ methods = AnnotationSupport .findAnnotatedMethods(context.requiredTestClass, Pact ::class .java, HierarchyTraversalMode .TOP_DOWN )
275
+ .map { m -> m.name}
276
+ }
277
+
262
278
return when {
263
279
store[key] != null -> store[key] as BasePact
264
280
else -> {
265
- val pact = if (pactMethods.isEmpty()) {
266
- lookupPact(providerInfo, " " , context)
267
- } else {
268
- val head = pactMethods.first()
269
- val tail = pactMethods.drop(1 )
270
- val initial = lookupPact(providerInfo, head, context)
271
- tail.fold(initial) { acc, method ->
272
- val pact = lookupPact(providerInfo, method, context)
273
-
274
- if (pact.provider != acc.provider) {
275
- // Should not really get here, as the Pacts should have been sorted by provider
276
- throw IllegalArgumentException (" You are using different Pacts with different providers for the same test" +
277
- " ('${acc.provider} ') and '${pact.provider} '). A separate test (and ideally a separate test class)" +
278
- " should be used for each provider." )
279
- }
281
+ val head = methods.first()
282
+ val tail = methods.drop(1 )
283
+ val initial = lookupPact(providerInfo, head, context)
284
+ val pact = tail.fold(initial) { acc, method ->
285
+ val pact = lookupPact(providerInfo, method, context)
286
+
287
+ if (pact.provider != acc.provider) {
288
+ // Should not really get here, as the Pacts should have been sorted by provider
289
+ throw IllegalArgumentException (" You are using different Pacts with different providers for the same test" +
290
+ " ('${acc.provider} ') and '${pact.provider} '). A separate test (and ideally a separate test class)" +
291
+ " should be used for each provider." )
292
+ }
280
293
281
- if (pact.consumer != acc.consumer) {
282
- logger.warn {
283
- " WARNING: You are using different Pacts with different consumers for the same test " +
284
- " ('${acc.consumer} ') and '${pact.consumer} '). The second consumer will be ignored and dropped from " +
285
- " the Pact and the interactions merged. If this is not your intention, you need to create a " +
286
- " separate test for each consumer."
287
- }
294
+ if (pact.consumer != acc.consumer) {
295
+ logger.warn {
296
+ " WARNING: You are using different Pacts with different consumers for the same test " +
297
+ " ('${acc.consumer} ') and '${pact.consumer} '). The second consumer will be ignored and dropped from " +
298
+ " the Pact and the interactions merged. If this is not your intention, you need to create a " +
299
+ " separate test for each consumer."
288
300
}
289
-
290
- acc.mergeInteractions(pact.interactions) as BasePact
291
301
}
302
+
303
+ acc.mergeInteractions(pact.interactions) as BasePact
292
304
}
293
305
store.put(key, pact)
294
306
pact
@@ -541,7 +553,7 @@ class PactConsumerTestExt : Extension, BeforeTestExecutionCallback, BeforeAllCal
541
553
ProviderType .ASYNCH -> {
542
554
if (method.parameterTypes[0 ].isAssignableFrom(Class .forName(" au.com.dius.pact.consumer.MessagePactBuilder" ))) {
543
555
ReflectionSupport .invokeMethod(
544
- method, context.requiredTestInstance ,
556
+ method, context.testInstance ,
545
557
MessagePactBuilder (providerInfo.pactVersion ? : PactSpecVersion .V3 )
546
558
.consumer(pactConsumer).hasPactWith(providerNameToUse)
547
559
) as BasePact
@@ -550,7 +562,7 @@ class PactConsumerTestExt : Extension, BeforeTestExecutionCallback, BeforeAllCal
550
562
if (providerInfo.pactVersion != null ) {
551
563
pactBuilder.pactSpecVersion(providerInfo.pactVersion)
552
564
}
553
- ReflectionSupport .invokeMethod(method, context.requiredTestInstance , pactBuilder) as BasePact
565
+ ReflectionSupport .invokeMethod(method, context.testInstance , pactBuilder) as BasePact
554
566
}
555
567
}
556
568
ProviderType .SYNCH_MESSAGE -> {
0 commit comments