diff --git a/src/main/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsFactoryImpl.java b/src/main/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsFactoryImpl.java index 03d8f78..16aeb25 100644 --- a/src/main/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsFactoryImpl.java +++ b/src/main/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsFactoryImpl.java @@ -30,21 +30,21 @@ public class MetricsStatsFactoryImpl implements MetricsStatsFactory { @Override public void postInit(KeycloakSessionFactory factory) { - if (!"true".equals(System.getenv().get("KC_METRICS_STATS_ENABLED"))) { + if (!"true".equals(getenv("KC_METRICS_STATS_ENABLED"))) { log.infov("Keycloak stats not enabled."); return; } var intervalDuration = Optional - .ofNullable(System.getenv("KC_METRICS_STATS_INTERVAL")) + .ofNullable(getenv("KC_METRICS_STATS_INTERVAL")) .map(Duration::parse) .orElse(Duration.ofSeconds(60)); var infoThreshold = Optional - .ofNullable(System.getenv("KC_METRICS_STATS_INFO_THRESHOLD")) + .ofNullable(getenv("KC_METRICS_STATS_INFO_THRESHOLD")) .map(Duration::parse) .orElse(Duration.ofMillis(Double.valueOf(intervalDuration.toMillis() * 0.5).longValue())); var warnThreshold = Optional - .ofNullable(System.getenv("KC_METRICS_STATS_WARN_THRESHOLD")) + .ofNullable(getenv("KC_METRICS_STATS_WARN_THRESHOLD")) .map(Duration::parse) .orElse(Duration.ofMillis(Double.valueOf(intervalDuration.toMillis() * 0.75).longValue())); log.infov("Keycloak stats enabled with interval of {0} and info/warn after {1}/{2}.", @@ -64,4 +64,8 @@ public class MetricsStatsFactoryImpl implements MetricsStatsFactory { @Override public void close() {} + + String getenv(String key) { + return System.getenv().get(key); + } } diff --git a/src/main/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsTask.java b/src/main/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsTask.java index 22f7c3c..eabfe2f 100644 --- a/src/main/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsTask.java +++ b/src/main/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsTask.java @@ -43,7 +43,7 @@ public class MetricsStatsTask implements Provider, ScheduledTask { scrape(session); } catch (Exception e) { if (e instanceof org.hibernate.exception.SQLGrammarException) { - log.infov("Metrics status task skipped, database not ready"); + log.infov("Metrics status task skipped, database not ready."); } else { log.errorv(e, "Failed to scrape stats."); } @@ -52,13 +52,13 @@ public class MetricsStatsTask implements Provider, ScheduledTask { var duration = Duration.between(start, Instant.now()); if (duration.compareTo(interval) > 0) { - log.errorv("Finished scrapping keycloak stats in {0}, consider to increase interval", duration); + log.errorv("Finished scrapping keycloak stats in {0}, consider to increase interval.", duration); } else if (duration.compareTo(warnThreshold) > 0) { - log.warnv("Finished scrapping keycloak stats in {0}, consider to increase interval", duration); + log.warnv("Finished scrapping keycloak stats in {0}, consider to increase interval.", duration); } else if (duration.compareTo(infoThreshold) > 0) { - log.infov("Finished scrapping keycloak stats in {0}", duration); + log.infov("Finished scrapping keycloak stats in {0}.", duration); } else { - log.debugv("Finished scrapping keycloak stats in {0}", duration); + log.debugv("Finished scrapping keycloak stats in {0}.", duration); } } diff --git a/src/test/java/io/kokuwa/keycloak/metrics/junit/AbstractMockitoTest.java b/src/test/java/io/kokuwa/keycloak/metrics/junit/AbstractMockitoTest.java index e42de7c..1a3d597 100644 --- a/src/test/java/io/kokuwa/keycloak/metrics/junit/AbstractMockitoTest.java +++ b/src/test/java/io/kokuwa/keycloak/metrics/junit/AbstractMockitoTest.java @@ -1,5 +1,14 @@ package io.kokuwa.keycloak.metrics.junit; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.ClassOrderer; import org.junit.jupiter.api.MethodOrderer; @@ -21,9 +30,42 @@ import io.micrometer.core.instrument.simple.SimpleMeterRegistry; @TestMethodOrder(MethodOrderer.DisplayName.class) public abstract class AbstractMockitoTest { + private static final List LOGS = new ArrayList<>(); + + static { + + System.setProperty("org.jboss.logging.provider", "jdk"); + System.setProperty("java.util.logging.SimpleFormatter.format", "%1$tT %4$-5s %2$s %5$s%6$s%n"); + + Logger.getLogger("org.junit").setLevel(Level.INFO); + Logger.getLogger("").setLevel(Level.ALL); + Logger.getLogger("").addHandler(new Handler() { + + @Override + public void publish(LogRecord log) { + LOGS.add(log); + } + + @Override + public void flush() {} + + @Override + public void close() {} + }); + } + @BeforeEach void reset() { Metrics.globalRegistry.clear(); Metrics.addRegistry(new SimpleMeterRegistry()); + LOGS.clear(); + } + + public static void assertLog(Level level, String message) { + assertTrue(LOGS.stream() + .filter(l -> l.getLevel().equals(level)) + .filter(l -> l.getMessage().equals(message)) + .findAny().isPresent(), + "log with level " + level + " and message " + message + " not found"); } } diff --git a/src/test/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsFactoryTest.java b/src/test/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsFactoryTest.java new file mode 100644 index 0000000..1a8a794 --- /dev/null +++ b/src/test/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsFactoryTest.java @@ -0,0 +1,107 @@ +package io.kokuwa.keycloak.metrics.stats; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Duration; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.platform.commons.util.ReflectionUtils; +import org.keycloak.models.KeycloakSession; +import org.keycloak.models.KeycloakSessionFactory; +import org.keycloak.models.KeycloakTransactionManager; +import org.keycloak.timer.TimerProvider; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.Spy; + +import io.kokuwa.keycloak.metrics.junit.AbstractMockitoTest; + +/** + * Test for {@link MetricsStatsFactory} with Mockito. + * + * @author Stephan Schnabel + */ +@DisplayName("metrics: factory") +public class MetricsStatsFactoryTest extends AbstractMockitoTest { + + @Spy + MetricsStatsFactoryImpl factory; + @Mock + KeycloakSessionFactory sessionFactory; + @Mock + KeycloakSession session; + + @DisplayName("disabled") + @Test + void disabled() { + factory.init(null); + factory.postInit(sessionFactory); + assertNull(factory.create(session)); + factory.close(); + } + + @DisplayName("enabled - with default values") + @Test + void enabledDefault() { + when(factory.getenv("KC_METRICS_STATS_ENABLED")).thenReturn("true"); + when(factory.getenv("KC_METRICS_STATS_INTERVAL")).thenReturn(null); + when(factory.getenv("KC_METRICS_STATS_INFO_THRESHOLD")).thenReturn(null); + when(factory.getenv("KC_METRICS_STATS_WARN_THRESHOLD")).thenReturn(null); + assertTask(Duration.ofSeconds(60), Duration.ofSeconds(30), Duration.ofSeconds(45)); + } + + @DisplayName("enabled - with custom interval") + @Test + void enabledCustomInterval() { + when(factory.getenv("KC_METRICS_STATS_ENABLED")).thenReturn("true"); + when(factory.getenv("KC_METRICS_STATS_INTERVAL")).thenReturn("PT300s"); + when(factory.getenv("KC_METRICS_STATS_INFO_THRESHOLD")).thenReturn(null); + when(factory.getenv("KC_METRICS_STATS_WARN_THRESHOLD")).thenReturn(null); + assertTask(Duration.ofSeconds(300), Duration.ofSeconds(150), Duration.ofSeconds(225)); + } + + @DisplayName("enabled - with custom thresholds") + @Test + void enabledCustomThresholds() { + when(factory.getenv("KC_METRICS_STATS_ENABLED")).thenReturn("true"); + when(factory.getenv("KC_METRICS_STATS_INTERVAL")).thenReturn(null); + when(factory.getenv("KC_METRICS_STATS_INFO_THRESHOLD")).thenReturn("PT40s"); + when(factory.getenv("KC_METRICS_STATS_WARN_THRESHOLD")).thenReturn("PT50s"); + assertTask(Duration.ofSeconds(60), Duration.ofSeconds(40), Duration.ofSeconds(50)); + } + + private void assertTask(Duration interval, Duration infoThreshold, Duration warnThreshold) { + + var timerProvider = mock(TimerProvider.class); + when(sessionFactory.create()).thenReturn(session); + when(session.getProvider(TimerProvider.class)).thenReturn(timerProvider); + when(session.getTransactionManager()).thenReturn(mock(KeycloakTransactionManager.class)); + + factory.postInit(sessionFactory); + + var taskCaptor = ArgumentCaptor.forClass(MetricsStatsTask.class); + verify(timerProvider).scheduleTask( + taskCaptor.capture(), + ArgumentMatchers.eq(interval.toMillis()), + ArgumentMatchers.eq("metrics")); + assertNotNull(taskCaptor.getValue(), "task"); + assertField(interval, taskCaptor.getValue(), "interval"); + assertField(infoThreshold, taskCaptor.getValue(), "infoThreshold"); + assertField(warnThreshold, taskCaptor.getValue(), "warnThreshold"); + } + + private void assertField(Duration expected, MetricsStatsTask task, String name) { + assertEquals( + expected, + assertDoesNotThrow(() -> ReflectionUtils.tryToReadFieldValue(MetricsStatsTask.class, name, task).get()), + "field " + name + " invalid"); + } +} diff --git a/src/test/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsSpiTest.java b/src/test/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsSpiTest.java new file mode 100644 index 0000000..e318262 --- /dev/null +++ b/src/test/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsSpiTest.java @@ -0,0 +1,37 @@ +package io.kokuwa.keycloak.metrics.stats; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ServiceLoader; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +import io.kokuwa.keycloak.metrics.junit.AbstractMockitoTest; + +/** + * Test for {@link MetricsStatsSpi} with Mockito. + * + * @author Stephan Schnabel + */ +@DisplayName("metrics: spi") +public class MetricsStatsSpiTest extends AbstractMockitoTest { + + @Test + void test() { + + var spi = new MetricsStatsSpi(); + assertEquals("metrics", spi.getName(), "getName()"); + assertFalse(spi.isInternal(), "isInternal()"); + assertNotNull(spi.getProviderClass(), "getProviderClass()"); + assertTrue(spi.getProviderFactoryClass().isInterface(), "getProviderFactoryClass() - should be an interface"); + + var factory = ServiceLoader.load(spi.getProviderFactoryClass()).findFirst().orElse(null); + assertNotNull(factory, "failed to read factory with service loader"); + assertEquals(MetricsStatsFactoryImpl.class, factory.getClass(), "factory.class"); + assertEquals("default", factory.getId(), "factory.id"); + } +} diff --git a/src/test/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsTaskTest.java b/src/test/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsTaskTest.java new file mode 100644 index 0000000..7e46b8e --- /dev/null +++ b/src/test/java/io/kokuwa/keycloak/metrics/stats/MetricsStatsTaskTest.java @@ -0,0 +1,225 @@ +package io.kokuwa.keycloak.metrics.stats; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.logging.Level; +import java.util.stream.Stream; + +import org.hibernate.exception.SQLGrammarException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.keycloak.models.ClientModel; +import org.keycloak.models.ClientProvider; +import org.keycloak.models.KeycloakSession; +import org.keycloak.models.RealmModel; +import org.keycloak.models.RealmProvider; +import org.keycloak.models.UserProvider; +import org.keycloak.models.UserSessionProvider; +import org.mockito.Mock; + +import io.kokuwa.keycloak.metrics.junit.AbstractMockitoTest; +import io.micrometer.core.instrument.Gauge; +import io.micrometer.core.instrument.Metrics; + +/** + * Test for {@link MetricsStatsTask} with Mockito. + * + * @author Stephan Schnabel + */ +@DisplayName("metrics: task") +public class MetricsStatsTaskTest extends AbstractMockitoTest { + + @Mock + KeycloakSession session; + @Mock + RealmProvider realmProvider; + @Mock + UserProvider userProvider; + @Mock + UserSessionProvider sessionProvider; + @Mock + ClientProvider clientProvider; + + @BeforeEach + void setup() { + when(session.realms()).thenReturn(realmProvider); + } + + @DisplayName("catch - nullpointer") + @Test + void catchNPE() { + when(session.realms()).thenThrow(NullPointerException.class); + task().run(session); + assertLog(Level.SEVERE, "Failed to scrape stats."); + } + + @DisplayName("catch - database") + @Test + void catchDatabase() { + when(session.realms()).thenThrow(SQLGrammarException.class); + task().run(session); + assertLog(Level.INFO, "Metrics status task skipped, database not ready."); + } + + @DisplayName("log - debug") + @Test + void logDebug() { + when(realmProvider.getRealmsStream()).thenReturn(Stream.of()); + task(Duration.ofMillis(300), Duration.ofMillis(100), Duration.ofMillis(200)).run(session); + assertLog(Level.FINE, "Finished scrapping keycloak stats in {0}."); + } + + @DisplayName("log - info") + @Test + void logInfo() { + when(realmProvider.getRealmsStream()).thenReturn(Stream.of()); + task(Duration.ofMillis(300), Duration.ZERO, Duration.ofMillis(200)).run(session); + assertLog(Level.INFO, "Finished scrapping keycloak stats in {0}."); + } + + @DisplayName("log - warn") + @Test + void logWarn() { + when(realmProvider.getRealmsStream()).thenReturn(Stream.of()); + task(Duration.ofMillis(300), Duration.ofMillis(100), Duration.ZERO).run(session); + assertLog(Level.WARNING, "Finished scrapping keycloak stats in {0}, consider to increase interval."); + } + + @DisplayName("log - error") + @Test + void logError() { + when(realmProvider.getRealmsStream()).thenReturn(Stream.of()); + task(Duration.ZERO, Duration.ofMillis(100), Duration.ofMillis(200)).run(session); + assertLog(Level.SEVERE, "Finished scrapping keycloak stats in {0}, consider to increase interval."); + } + + @DisplayName("scrape") + @Test + void scrape() { + + var realm = UUID.randomUUID().toString(); + var realmModel = mock(RealmModel.class); + var client1 = UUID.randomUUID().toString(); + var client1Id = UUID.randomUUID().toString(); + var client1Model = mock(ClientModel.class); + var client2 = UUID.randomUUID().toString(); + var client2Id = UUID.randomUUID().toString(); + var client2Model = mock(ClientModel.class); + when(realmModel.getName()).thenReturn(realm); + when(realmModel.getClientsStream()).then(i -> Stream.of(client1Model, client2Model)); + when(client1Model.getId()).thenReturn(client1Id); + when(client1Model.getClientId()).thenReturn(client1); + when(client2Model.getId()).thenReturn(client2Id); + when(client2Model.getClientId()).thenReturn(client2); + + when(session.clients()).thenReturn(clientProvider); + when(session.users()).thenReturn(userProvider); + when(session.sessions()).thenReturn(sessionProvider); + when(realmProvider.getRealmsStream()).then(i -> Stream.of(realmModel)); + + // empty realm + + when(userProvider.getUsersCount(realmModel)).thenReturn(0); + when(clientProvider.getClientsCount(realmModel)).thenReturn(0L); + when(sessionProvider.getOfflineSessionsCount(realmModel, client1Model)).thenReturn(0L); + when(sessionProvider.getOfflineSessionsCount(realmModel, client2Model)).thenReturn(0L); + when(sessionProvider.getActiveUserSessions(realmModel, client1Model)).thenReturn(0L); + when(sessionProvider.getActiveUserSessions(realmModel, client2Model)).thenReturn(0L); + when(sessionProvider.getActiveClientSessionStats(realmModel, false)).thenReturn(Map.of()); + task().run(session); + assertUsersCount(realmModel, 0); + assertClientsCount(realmModel, 0); + assertOfflineSessions(realmModel, client1Model, 0); + assertOfflineSessions(realmModel, client2Model, 0); + assertActiveUserSessions(realmModel, client1Model, 0); + assertActiveUserSessions(realmModel, client2Model, 0); + assertActiveClientSessions(realmModel, client1Model, 0); + assertActiveClientSessions(realmModel, client2Model, 0); + + // initial values + + when(userProvider.getUsersCount(realmModel)).thenReturn(10); + when(clientProvider.getClientsCount(realmModel)).thenReturn(20L); + when(sessionProvider.getOfflineSessionsCount(realmModel, client1Model)).thenReturn(0L); + when(sessionProvider.getOfflineSessionsCount(realmModel, client2Model)).thenReturn(1L); + when(sessionProvider.getActiveUserSessions(realmModel, client1Model)).thenReturn(2L); + when(sessionProvider.getActiveUserSessions(realmModel, client2Model)).thenReturn(3L); + when(sessionProvider.getActiveClientSessionStats(realmModel, false)) + .thenReturn(Map.of(client1Id, 5L, client2Id, 0L)); + task().run(session); + assertUsersCount(realmModel, 10); + assertClientsCount(realmModel, 20); + assertOfflineSessions(realmModel, client1Model, 0); + assertOfflineSessions(realmModel, client2Model, 1); + assertActiveUserSessions(realmModel, client1Model, 2); + assertActiveUserSessions(realmModel, client2Model, 3); + assertActiveClientSessions(realmModel, client1Model, 5); + assertActiveClientSessions(realmModel, client2Model, 0); + + // updated values + + when(userProvider.getUsersCount(realmModel)).thenReturn(11); + when(clientProvider.getClientsCount(realmModel)).thenReturn(19L); + when(sessionProvider.getOfflineSessionsCount(realmModel, client1Model)).thenReturn(3L); + when(sessionProvider.getOfflineSessionsCount(realmModel, client2Model)).thenReturn(2L); + when(sessionProvider.getActiveUserSessions(realmModel, client1Model)).thenReturn(1L); + when(sessionProvider.getActiveUserSessions(realmModel, client2Model)).thenReturn(0L); + when(sessionProvider.getActiveClientSessionStats(realmModel, false)) + .thenReturn(Map.of(client1Id, 4L, client2Id, 3L)); + task().run(session); + assertUsersCount(realmModel, 11); + assertClientsCount(realmModel, 19); + assertOfflineSessions(realmModel, client1Model, 3); + assertOfflineSessions(realmModel, client2Model, 2); + assertActiveUserSessions(realmModel, client1Model, 1); + assertActiveUserSessions(realmModel, client2Model, 0); + assertActiveClientSessions(realmModel, client1Model, 4); + assertActiveClientSessions(realmModel, client2Model, 3); + } + + private MetricsStatsTask task() { + return task(Duration.ofMillis(300), Duration.ofMillis(100), Duration.ofMillis(200)); + } + + private MetricsStatsTask task(Duration interval, Duration infoThreshold, Duration warnThreshold) { + return new MetricsStatsTask(interval, infoThreshold, warnThreshold); + } + + private static void assertUsersCount(RealmModel realm, int count) { + assertGauge("keycloak_users", realm, null, count); + } + + private static void assertClientsCount(RealmModel realm, int count) { + assertGauge("keycloak_clients", realm, null, count); + } + + private static void assertActiveClientSessions(RealmModel realm, ClientModel client, int count) { + assertGauge("keycloak_active_client_sessions", realm, client, count); + } + + private static void assertActiveUserSessions(RealmModel realm, ClientModel client, int count) { + assertGauge("keycloak_active_user_sessions", realm, client, count); + } + + private static void assertOfflineSessions(RealmModel realm, ClientModel client, int count) { + assertGauge("keycloak_offline_sessions", realm, client, count); + } + + private static void assertGauge(String name, RealmModel realm, ClientModel client, int count) { + var gauges = Metrics.globalRegistry.getMeters().stream() + .filter(Gauge.class::isInstance) + .filter(gauge -> gauge.getId().getName().equals(name)) + .filter(gauge -> gauge.getId().getTag("realm").equals(realm.getName())) + .filter(gauge -> client == null || gauge.getId().getTag("client").equals(client.getClientId())) + .map(Gauge.class::cast) + .toList(); + assertEquals(1, gauges.size()); + assertEquals(count, gauges.get(0).value()); + } +}