diff --git a/src/main/java/org/springframework/data/redis/cache/RedisCache.java b/src/main/java/org/springframework/data/redis/cache/RedisCache.java index b5567b9fbe..eaff14a2ad 100644 --- a/src/main/java/org/springframework/data/redis/cache/RedisCache.java +++ b/src/main/java/org/springframework/data/redis/cache/RedisCache.java @@ -34,6 +34,7 @@ import org.springframework.data.redis.util.ByteUtils; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ConcurrentReferenceHashMap; import org.springframework.util.ObjectUtils; import org.springframework.util.ReflectionUtils; @@ -44,6 +45,7 @@ * * @author Christoph Strobl * @author Mark Paluch + * @author Piotr Mionskowski * @see RedisCacheConfiguration * @see RedisCacheWriter * @since 2.0 @@ -118,7 +120,7 @@ public RedisCacheWriter getNativeCache() { */ @Override @SuppressWarnings("unchecked") - public synchronized T get(Object key, Callable valueLoader) { + public T get(Object key, Callable valueLoader) { ValueWrapper result = get(key); @@ -126,9 +128,26 @@ public synchronized T get(Object key, Callable valueLoader) { return (T) result.get(); } - T value = valueFromLoader(key, valueLoader); - put(key, value); - return value; + return getSynchronized(key, valueLoader); + } + + private final ConcurrentReferenceHashMap valueLoaderLocks = new ConcurrentReferenceHashMap<>(); + + @SuppressWarnings({"unchecked", "SynchronizationOnLocalVariableOrMethodParameter"}) + private T getSynchronized(Object key, Callable valueLoader) { + final Object loaderLock = valueLoaderLocks.computeIfAbsent(createCacheKey(key), (String k) -> new Object()); + + synchronized (loaderLock) { + ValueWrapper result = get(key); + + if (result != null) { + return (T) result.get(); + } + + T value = valueFromLoader(key, valueLoader); + put(key, value); + return value; + } } /* diff --git a/src/test/java/org/springframework/data/redis/cache/RedisCacheTests.java b/src/test/java/org/springframework/data/redis/cache/RedisCacheTests.java index ef29b13782..ad4f90adc1 100644 --- a/src/test/java/org/springframework/data/redis/cache/RedisCacheTests.java +++ b/src/test/java/org/springframework/data/redis/cache/RedisCacheTests.java @@ -28,7 +28,12 @@ import java.util.Collection; import java.util.Collections; import java.util.Date; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; +import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; @@ -409,15 +414,47 @@ void cacheShouldFailOnNonConvertibleCacheKey() { assertThatExceptionOfType(IllegalStateException.class).isThrownBy(() -> cache.put(key, sample)); } - void doWithConnection(Consumer callback) { - RedisConnection connection = connectionFactory.getConnection(); - try { - callback.accept(connection); - } finally { - connection.close(); - } + @ParameterizedRedisTest // GH-2079 + void multipleThreadsLoadValueOnce() { + + int threadCount = 5; + + ConcurrentMap valuesByThreadId = new ConcurrentHashMap<>(threadCount); + + CountDownLatch waiter = new CountDownLatch(threadCount); + + AtomicInteger threadIds = new AtomicInteger(0); + + AtomicInteger currentValueForKey = new AtomicInteger(0); + + Stream.generate(threadIds::getAndIncrement) + .limit(threadCount) + .parallel() + .forEach((threadId) -> { + waiter.countDown(); + try { + waiter.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + Integer valueForThread = cache.get("key", currentValueForKey::incrementAndGet); + valuesByThreadId.put(threadId, valueForThread); + }); + + valuesByThreadId.forEach((thread, valueForThread) -> { + assertThat(valueForThread).isEqualTo(currentValueForKey.get()); + }); } + void doWithConnection(Consumer callback) { + RedisConnection connection = connectionFactory.getConnection(); + try { + callback.accept(connection); + } finally { + connection.close(); + } + } + @Data @NoArgsConstructor @AllArgsConstructor