Add unit tests with mockito for metric stats (#38)

This commit is contained in:
Stephan Schnabel 2023-05-02 09:30:58 +02:00 committed by GitHub
parent cfce0dee45
commit 99a01e5c02
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 424 additions and 9 deletions

View file

@ -1,5 +1,14 @@
package io.kokuwa.keycloak.metrics.junit;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.LogRecord;
import java.util.logging.Logger;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.ClassOrderer;
import org.junit.jupiter.api.MethodOrderer;
@ -21,9 +30,42 @@ import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
@TestMethodOrder(MethodOrderer.DisplayName.class)
public abstract class AbstractMockitoTest {
private static final List<LogRecord> LOGS = new ArrayList<>();
static {
System.setProperty("org.jboss.logging.provider", "jdk");
System.setProperty("java.util.logging.SimpleFormatter.format", "%1$tT %4$-5s %2$s %5$s%6$s%n");
Logger.getLogger("org.junit").setLevel(Level.INFO);
Logger.getLogger("").setLevel(Level.ALL);
Logger.getLogger("").addHandler(new Handler() {
@Override
public void publish(LogRecord log) {
LOGS.add(log);
}
@Override
public void flush() {}
@Override
public void close() {}
});
}
@BeforeEach
void reset() {
Metrics.globalRegistry.clear();
Metrics.addRegistry(new SimpleMeterRegistry());
LOGS.clear();
}
public static void assertLog(Level level, String message) {
assertTrue(LOGS.stream()
.filter(l -> l.getLevel().equals(level))
.filter(l -> l.getMessage().equals(message))
.findAny().isPresent(),
"log with level " + level + " and message " + message + " not found");
}
}

View file

@ -0,0 +1,107 @@
package io.kokuwa.keycloak.metrics.stats;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.time.Duration;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.platform.commons.util.ReflectionUtils;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.KeycloakTransactionManager;
import org.keycloak.timer.TimerProvider;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mock;
import org.mockito.Spy;
import io.kokuwa.keycloak.metrics.junit.AbstractMockitoTest;
/**
* Test for {@link MetricsStatsFactory} with Mockito.
*
* @author Stephan Schnabel
*/
@DisplayName("metrics: factory")
public class MetricsStatsFactoryTest extends AbstractMockitoTest {
@Spy
MetricsStatsFactoryImpl factory;
@Mock
KeycloakSessionFactory sessionFactory;
@Mock
KeycloakSession session;
@DisplayName("disabled")
@Test
void disabled() {
factory.init(null);
factory.postInit(sessionFactory);
assertNull(factory.create(session));
factory.close();
}
@DisplayName("enabled - with default values")
@Test
void enabledDefault() {
when(factory.getenv("KC_METRICS_STATS_ENABLED")).thenReturn("true");
when(factory.getenv("KC_METRICS_STATS_INTERVAL")).thenReturn(null);
when(factory.getenv("KC_METRICS_STATS_INFO_THRESHOLD")).thenReturn(null);
when(factory.getenv("KC_METRICS_STATS_WARN_THRESHOLD")).thenReturn(null);
assertTask(Duration.ofSeconds(60), Duration.ofSeconds(30), Duration.ofSeconds(45));
}
@DisplayName("enabled - with custom interval")
@Test
void enabledCustomInterval() {
when(factory.getenv("KC_METRICS_STATS_ENABLED")).thenReturn("true");
when(factory.getenv("KC_METRICS_STATS_INTERVAL")).thenReturn("PT300s");
when(factory.getenv("KC_METRICS_STATS_INFO_THRESHOLD")).thenReturn(null);
when(factory.getenv("KC_METRICS_STATS_WARN_THRESHOLD")).thenReturn(null);
assertTask(Duration.ofSeconds(300), Duration.ofSeconds(150), Duration.ofSeconds(225));
}
@DisplayName("enabled - with custom thresholds")
@Test
void enabledCustomThresholds() {
when(factory.getenv("KC_METRICS_STATS_ENABLED")).thenReturn("true");
when(factory.getenv("KC_METRICS_STATS_INTERVAL")).thenReturn(null);
when(factory.getenv("KC_METRICS_STATS_INFO_THRESHOLD")).thenReturn("PT40s");
when(factory.getenv("KC_METRICS_STATS_WARN_THRESHOLD")).thenReturn("PT50s");
assertTask(Duration.ofSeconds(60), Duration.ofSeconds(40), Duration.ofSeconds(50));
}
private void assertTask(Duration interval, Duration infoThreshold, Duration warnThreshold) {
var timerProvider = mock(TimerProvider.class);
when(sessionFactory.create()).thenReturn(session);
when(session.getProvider(TimerProvider.class)).thenReturn(timerProvider);
when(session.getTransactionManager()).thenReturn(mock(KeycloakTransactionManager.class));
factory.postInit(sessionFactory);
var taskCaptor = ArgumentCaptor.forClass(MetricsStatsTask.class);
verify(timerProvider).scheduleTask(
taskCaptor.capture(),
ArgumentMatchers.eq(interval.toMillis()),
ArgumentMatchers.eq("metrics"));
assertNotNull(taskCaptor.getValue(), "task");
assertField(interval, taskCaptor.getValue(), "interval");
assertField(infoThreshold, taskCaptor.getValue(), "infoThreshold");
assertField(warnThreshold, taskCaptor.getValue(), "warnThreshold");
}
private void assertField(Duration expected, MetricsStatsTask task, String name) {
assertEquals(
expected,
assertDoesNotThrow(() -> ReflectionUtils.tryToReadFieldValue(MetricsStatsTask.class, name, task).get()),
"field " + name + " invalid");
}
}

View file

@ -0,0 +1,37 @@
package io.kokuwa.keycloak.metrics.stats;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.ServiceLoader;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import io.kokuwa.keycloak.metrics.junit.AbstractMockitoTest;
/**
* Test for {@link MetricsStatsSpi} with Mockito.
*
* @author Stephan Schnabel
*/
@DisplayName("metrics: spi")
public class MetricsStatsSpiTest extends AbstractMockitoTest {
@Test
void test() {
var spi = new MetricsStatsSpi();
assertEquals("metrics", spi.getName(), "getName()");
assertFalse(spi.isInternal(), "isInternal()");
assertNotNull(spi.getProviderClass(), "getProviderClass()");
assertTrue(spi.getProviderFactoryClass().isInterface(), "getProviderFactoryClass() - should be an interface");
var factory = ServiceLoader.load(spi.getProviderFactoryClass()).findFirst().orElse(null);
assertNotNull(factory, "failed to read factory with service loader");
assertEquals(MetricsStatsFactoryImpl.class, factory.getClass(), "factory.class");
assertEquals("default", factory.getId(), "factory.id");
}
}

View file

@ -0,0 +1,225 @@
package io.kokuwa.keycloak.metrics.stats;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.logging.Level;
import java.util.stream.Stream;
import org.hibernate.exception.SQLGrammarException;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.keycloak.models.ClientModel;
import org.keycloak.models.ClientProvider;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.RealmProvider;
import org.keycloak.models.UserProvider;
import org.keycloak.models.UserSessionProvider;
import org.mockito.Mock;
import io.kokuwa.keycloak.metrics.junit.AbstractMockitoTest;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.Metrics;
/**
* Test for {@link MetricsStatsTask} with Mockito.
*
* @author Stephan Schnabel
*/
@DisplayName("metrics: task")
public class MetricsStatsTaskTest extends AbstractMockitoTest {
@Mock
KeycloakSession session;
@Mock
RealmProvider realmProvider;
@Mock
UserProvider userProvider;
@Mock
UserSessionProvider sessionProvider;
@Mock
ClientProvider clientProvider;
@BeforeEach
void setup() {
when(session.realms()).thenReturn(realmProvider);
}
@DisplayName("catch - nullpointer")
@Test
void catchNPE() {
when(session.realms()).thenThrow(NullPointerException.class);
task().run(session);
assertLog(Level.SEVERE, "Failed to scrape stats.");
}
@DisplayName("catch - database")
@Test
void catchDatabase() {
when(session.realms()).thenThrow(SQLGrammarException.class);
task().run(session);
assertLog(Level.INFO, "Metrics status task skipped, database not ready.");
}
@DisplayName("log - debug")
@Test
void logDebug() {
when(realmProvider.getRealmsStream()).thenReturn(Stream.of());
task(Duration.ofMillis(300), Duration.ofMillis(100), Duration.ofMillis(200)).run(session);
assertLog(Level.FINE, "Finished scrapping keycloak stats in {0}.");
}
@DisplayName("log - info")
@Test
void logInfo() {
when(realmProvider.getRealmsStream()).thenReturn(Stream.of());
task(Duration.ofMillis(300), Duration.ZERO, Duration.ofMillis(200)).run(session);
assertLog(Level.INFO, "Finished scrapping keycloak stats in {0}.");
}
@DisplayName("log - warn")
@Test
void logWarn() {
when(realmProvider.getRealmsStream()).thenReturn(Stream.of());
task(Duration.ofMillis(300), Duration.ofMillis(100), Duration.ZERO).run(session);
assertLog(Level.WARNING, "Finished scrapping keycloak stats in {0}, consider to increase interval.");
}
@DisplayName("log - error")
@Test
void logError() {
when(realmProvider.getRealmsStream()).thenReturn(Stream.of());
task(Duration.ZERO, Duration.ofMillis(100), Duration.ofMillis(200)).run(session);
assertLog(Level.SEVERE, "Finished scrapping keycloak stats in {0}, consider to increase interval.");
}
@DisplayName("scrape")
@Test
void scrape() {
var realm = UUID.randomUUID().toString();
var realmModel = mock(RealmModel.class);
var client1 = UUID.randomUUID().toString();
var client1Id = UUID.randomUUID().toString();
var client1Model = mock(ClientModel.class);
var client2 = UUID.randomUUID().toString();
var client2Id = UUID.randomUUID().toString();
var client2Model = mock(ClientModel.class);
when(realmModel.getName()).thenReturn(realm);
when(realmModel.getClientsStream()).then(i -> Stream.of(client1Model, client2Model));
when(client1Model.getId()).thenReturn(client1Id);
when(client1Model.getClientId()).thenReturn(client1);
when(client2Model.getId()).thenReturn(client2Id);
when(client2Model.getClientId()).thenReturn(client2);
when(session.clients()).thenReturn(clientProvider);
when(session.users()).thenReturn(userProvider);
when(session.sessions()).thenReturn(sessionProvider);
when(realmProvider.getRealmsStream()).then(i -> Stream.of(realmModel));
// empty realm
when(userProvider.getUsersCount(realmModel)).thenReturn(0);
when(clientProvider.getClientsCount(realmModel)).thenReturn(0L);
when(sessionProvider.getOfflineSessionsCount(realmModel, client1Model)).thenReturn(0L);
when(sessionProvider.getOfflineSessionsCount(realmModel, client2Model)).thenReturn(0L);
when(sessionProvider.getActiveUserSessions(realmModel, client1Model)).thenReturn(0L);
when(sessionProvider.getActiveUserSessions(realmModel, client2Model)).thenReturn(0L);
when(sessionProvider.getActiveClientSessionStats(realmModel, false)).thenReturn(Map.of());
task().run(session);
assertUsersCount(realmModel, 0);
assertClientsCount(realmModel, 0);
assertOfflineSessions(realmModel, client1Model, 0);
assertOfflineSessions(realmModel, client2Model, 0);
assertActiveUserSessions(realmModel, client1Model, 0);
assertActiveUserSessions(realmModel, client2Model, 0);
assertActiveClientSessions(realmModel, client1Model, 0);
assertActiveClientSessions(realmModel, client2Model, 0);
// initial values
when(userProvider.getUsersCount(realmModel)).thenReturn(10);
when(clientProvider.getClientsCount(realmModel)).thenReturn(20L);
when(sessionProvider.getOfflineSessionsCount(realmModel, client1Model)).thenReturn(0L);
when(sessionProvider.getOfflineSessionsCount(realmModel, client2Model)).thenReturn(1L);
when(sessionProvider.getActiveUserSessions(realmModel, client1Model)).thenReturn(2L);
when(sessionProvider.getActiveUserSessions(realmModel, client2Model)).thenReturn(3L);
when(sessionProvider.getActiveClientSessionStats(realmModel, false))
.thenReturn(Map.of(client1Id, 5L, client2Id, 0L));
task().run(session);
assertUsersCount(realmModel, 10);
assertClientsCount(realmModel, 20);
assertOfflineSessions(realmModel, client1Model, 0);
assertOfflineSessions(realmModel, client2Model, 1);
assertActiveUserSessions(realmModel, client1Model, 2);
assertActiveUserSessions(realmModel, client2Model, 3);
assertActiveClientSessions(realmModel, client1Model, 5);
assertActiveClientSessions(realmModel, client2Model, 0);
// updated values
when(userProvider.getUsersCount(realmModel)).thenReturn(11);
when(clientProvider.getClientsCount(realmModel)).thenReturn(19L);
when(sessionProvider.getOfflineSessionsCount(realmModel, client1Model)).thenReturn(3L);
when(sessionProvider.getOfflineSessionsCount(realmModel, client2Model)).thenReturn(2L);
when(sessionProvider.getActiveUserSessions(realmModel, client1Model)).thenReturn(1L);
when(sessionProvider.getActiveUserSessions(realmModel, client2Model)).thenReturn(0L);
when(sessionProvider.getActiveClientSessionStats(realmModel, false))
.thenReturn(Map.of(client1Id, 4L, client2Id, 3L));
task().run(session);
assertUsersCount(realmModel, 11);
assertClientsCount(realmModel, 19);
assertOfflineSessions(realmModel, client1Model, 3);
assertOfflineSessions(realmModel, client2Model, 2);
assertActiveUserSessions(realmModel, client1Model, 1);
assertActiveUserSessions(realmModel, client2Model, 0);
assertActiveClientSessions(realmModel, client1Model, 4);
assertActiveClientSessions(realmModel, client2Model, 3);
}
private MetricsStatsTask task() {
return task(Duration.ofMillis(300), Duration.ofMillis(100), Duration.ofMillis(200));
}
private MetricsStatsTask task(Duration interval, Duration infoThreshold, Duration warnThreshold) {
return new MetricsStatsTask(interval, infoThreshold, warnThreshold);
}
private static void assertUsersCount(RealmModel realm, int count) {
assertGauge("keycloak_users", realm, null, count);
}
private static void assertClientsCount(RealmModel realm, int count) {
assertGauge("keycloak_clients", realm, null, count);
}
private static void assertActiveClientSessions(RealmModel realm, ClientModel client, int count) {
assertGauge("keycloak_active_client_sessions", realm, client, count);
}
private static void assertActiveUserSessions(RealmModel realm, ClientModel client, int count) {
assertGauge("keycloak_active_user_sessions", realm, client, count);
}
private static void assertOfflineSessions(RealmModel realm, ClientModel client, int count) {
assertGauge("keycloak_offline_sessions", realm, client, count);
}
private static void assertGauge(String name, RealmModel realm, ClientModel client, int count) {
var gauges = Metrics.globalRegistry.getMeters().stream()
.filter(Gauge.class::isInstance)
.filter(gauge -> gauge.getId().getName().equals(name))
.filter(gauge -> gauge.getId().getTag("realm").equals(realm.getName()))
.filter(gauge -> client == null || gauge.getId().getTag("client").equals(client.getClientId()))
.map(Gauge.class::cast)
.toList();
assertEquals(1, gauges.size());
assertEquals(count, gauges.get(0).value());
}
}