diff --git a/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseEngine.java b/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseEngine.java new file mode 100644 index 0000000..ea1277d --- /dev/null +++ b/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseEngine.java @@ -0,0 +1,48 @@ +package io.github.rieske.dbtest.extension; + +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.function.Consumer; + +abstract class DatabaseEngine { + private volatile boolean templateDatabaseMigrated = false; + + void ensureTemplateDatabaseMigrated(Consumer migrator) { + if (!templateDatabaseMigrated) { + synchronized (this) { + if (!templateDatabaseMigrated) { + migrateTemplateDatabase(migrator, dataSourceForDatabase(getTemplateDatabaseName())); + templateDatabaseMigrated = true; + } + } + } + } + + void createDatabase(String databaseName) { + executePrivileged("CREATE DATABASE " + databaseName); + } + + void dropDatabase(String databaseName) { + executePrivileged("DROP DATABASE " + databaseName); + } + + abstract void cloneTemplateDatabaseTo(String targetDatabaseName); + + void executePrivileged(String sql) { + DataSource dataSource = getPrivilegedDataSource(); + try (Connection conn = dataSource.getConnection()) { + conn.createStatement().execute(sql); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + abstract DataSource dataSourceForDatabase(String databaseName); + + abstract String getTemplateDatabaseName(); + + abstract DataSource getPrivilegedDataSource(); + + abstract void migrateTemplateDatabase(Consumer migrator, DataSource templateDataSource); +} diff --git a/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseState.java b/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseState.java index b572c1b..8d432fc 100644 --- a/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseState.java +++ b/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseState.java @@ -1,61 +1,112 @@ package io.github.rieske.dbtest.extension; +import javax.sql.DataSource; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Consumer; +import java.util.function.BiConsumer; -interface DatabaseState { - String ensureDatabaseCreated(Class testClass, Consumer databaseCreator); +abstract class DatabaseState { + protected final DatabaseEngine database; - static String newDatabaseName() { + DatabaseState(DatabaseEngine database) { + this.database = database; + } + + abstract String ensureDatabaseCreated(String databaseName, Class testClass, BiConsumer databaseCreator); + abstract void afterTestMethod(String databaseName); + + private static String newDatabaseName() { return "testdb_" + UUID.randomUUID().toString().replace('-', '_'); } - class PerMethod implements DatabaseState { + DataSource dataSourceForDatabase(String databaseName) { + return database.dataSourceForDatabase(databaseName); + } + + static class PerMethod extends DatabaseState { + PerMethod(DatabaseEngine database) { + super(database); + } + @Override - public String ensureDatabaseCreated(Class testClass, Consumer databaseCreator) { - String databaseName = newDatabaseName(); - databaseCreator.accept(databaseName); + String ensureDatabaseCreated(String databaseName, Class testClass, BiConsumer databaseCreator) { + if (databaseName != null) { + return databaseName; + } + databaseName = newDatabaseName(); + databaseCreator.accept(database, databaseName); return databaseName; } + + @Override + void afterTestMethod(String databaseName) { + if (databaseName != null) { + database.dropDatabase(databaseName); + } + } } - class PerClass implements DatabaseState { + static class PerClass extends DatabaseState { private final Map, String> perClassDatabases = new ConcurrentHashMap<>(); + PerClass(DatabaseEngine database) { + super(database); + } + @Override - public String ensureDatabaseCreated(Class testClass, Consumer databaseCreator) { + String ensureDatabaseCreated(String databaseName, Class testClass, BiConsumer databaseCreator) { + if (testClass == null) { + throw new IllegalStateException("Per-class database extension must be registered as a static test class field in order to use the datasource during test instance construction."); + } + if (databaseName != null && databaseName.equals(perClassDatabases.get(testClass))) { + return databaseName; + } String perClassDatabaseName = perClassDatabases.get(testClass); if (perClassDatabaseName == null) { synchronized (perClassDatabases) { perClassDatabaseName = perClassDatabases.get(testClass); if (perClassDatabaseName == null) { perClassDatabaseName = newDatabaseName(); - databaseCreator.accept(perClassDatabaseName); + databaseCreator.accept(database, perClassDatabaseName); perClassDatabases.put(testClass, perClassDatabaseName); } } } return perClassDatabaseName; } + + @Override + void afterTestMethod(String databaseName) { + } } - class PerExecution implements DatabaseState { + static class PerExecution extends DatabaseState { private final String perExecutionDatabaseName = newDatabaseName(); private volatile boolean perExecutionDatabaseCreated = false; + PerExecution(DatabaseEngine database) { + super(database); + } + @Override - public String ensureDatabaseCreated(Class testClass, Consumer databaseCreator) { + String ensureDatabaseCreated(String databaseName, Class testClass, BiConsumer databaseCreator) { + if (perExecutionDatabaseName.equals(databaseName)) { + return databaseName; + } if (!perExecutionDatabaseCreated) { synchronized (perExecutionDatabaseName) { if (!perExecutionDatabaseCreated) { - databaseCreator.accept(perExecutionDatabaseName); + databaseCreator.accept(database, perExecutionDatabaseName); perExecutionDatabaseCreated = true; } } } return perExecutionDatabaseName; } + + @Override + void afterTestMethod(String databaseName) { + } } } diff --git a/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseTestExtension.java b/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseTestExtension.java index 07584b5..45f04fd 100644 --- a/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseTestExtension.java +++ b/base/src/main/java/io/github/rieske/dbtest/extension/DatabaseTestExtension.java @@ -4,15 +4,19 @@ import org.junit.jupiter.api.extension.BeforeEachCallback; import org.junit.jupiter.api.extension.Extension; import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.InvocationInterceptor; +import org.junit.jupiter.api.extension.ReflectiveInvocationContext; import javax.sql.DataSource; +import java.lang.reflect.Constructor; +import java.util.function.BiConsumer; import java.util.function.Consumer; /** * Base class for concrete database test extension implementations. * Encapsulates the core extension behavior that does not rely on backing database specifics. */ -public abstract class DatabaseTestExtension implements Extension, BeforeEachCallback, AfterEachCallback { +public abstract class DatabaseTestExtension implements Extension, BeforeEachCallback, AfterEachCallback, InvocationInterceptor { /** * Extension execution mode. Defines the database state guarantees for test executions. @@ -29,44 +33,38 @@ public enum Mode { /** * Create a single database per JVM process. Any state written by tests will be visible to other tests. */ - DATABASE_PER_EXECUTION; - - private String ensureDatabaseCreated(TestDatabase database, Consumer databaseCreator, Class testClass) { - switch (this) { - case DATABASE_PER_TEST_METHOD: - return database.perMethod.ensureDatabaseCreated(testClass, databaseCreator); - case DATABASE_PER_TEST_CLASS: - return database.perClass.ensureDatabaseCreated(testClass, databaseCreator); - case DATABASE_PER_EXECUTION: - return database.perExecution.ensureDatabaseCreated(testClass, databaseCreator); - default: - throw new IllegalStateException("No strategy exists for " + this + " mode"); - } - } + DATABASE_PER_EXECUTION } - private final TestDatabase database; - private final Consumer databaseCreator; - private final Mode mode; + private final DatabaseState databaseState; + private final BiConsumer databaseCreator; + private Class testClass; private String databaseName; DatabaseTestExtension(TestDatabase database, Mode mode, boolean migrateOnce) { - this.database = database; - this.databaseCreator = makeDatabaseCreator(database, migrateOnce); - this.mode = mode; + this.databaseState = database.getState(mode); + this.databaseCreator = makeDatabaseCreator(migrateOnce, this::migrateDatabase); + } + + @Override + public T interceptTestClassConstructor( + Invocation invocation, + ReflectiveInvocationContext> invocationContext, + ExtensionContext extensionContext + ) throws Throwable { + this.testClass = extensionContext.getRequiredTestClass(); + return invocation.proceed(); } @Override public void beforeEach(ExtensionContext context) { - this.databaseName = mode.ensureDatabaseCreated(database, databaseCreator, context.getRequiredTestClass()); + this.testClass = context.getRequiredTestClass(); } @Override public void afterEach(ExtensionContext context) { - if (mode == Mode.DATABASE_PER_TEST_METHOD) { - database.dropDatabase(databaseName); - } + databaseState.afterTestMethod(databaseName); } /** @@ -76,7 +74,8 @@ public void afterEach(ExtensionContext context) { * @return dataSource for a migrated database */ public DataSource getDataSource() { - return database.dataSourceForDatabase(databaseName); + this.databaseName = databaseState.ensureDatabaseCreated(databaseName, testClass, databaseCreator); + return databaseState.dataSourceForDatabase(databaseName); } /** @@ -87,16 +86,16 @@ public DataSource getDataSource() { */ abstract protected void migrateDatabase(DataSource dataSource); - private Consumer makeDatabaseCreator(TestDatabase database, boolean migrateOnce) { + private static BiConsumer makeDatabaseCreator(boolean migrateOnce, Consumer migrator) { if (migrateOnce) { - return databaseName -> { - database.ensureTemplateDatabaseMigrated(this::migrateDatabase); + return (database, databaseName) -> { + database.ensureTemplateDatabaseMigrated(migrator); database.cloneTemplateDatabaseTo(databaseName); }; } else { - return databaseName -> { + return (database, databaseName) -> { database.createDatabase(databaseName); - migrateDatabase(database.dataSourceForDatabase(databaseName)); + migrator.accept(database.dataSourceForDatabase(databaseName)); }; } } diff --git a/base/src/main/java/io/github/rieske/dbtest/extension/TestDatabase.java b/base/src/main/java/io/github/rieske/dbtest/extension/TestDatabase.java index 8fa1e32..cccee1e 100644 --- a/base/src/main/java/io/github/rieske/dbtest/extension/TestDatabase.java +++ b/base/src/main/java/io/github/rieske/dbtest/extension/TestDatabase.java @@ -1,52 +1,26 @@ package io.github.rieske.dbtest.extension; -import javax.sql.DataSource; -import java.sql.Connection; -import java.sql.SQLException; -import java.util.function.Consumer; - -abstract class TestDatabase { - final DatabaseState perMethod = new DatabaseState.PerMethod(); - final DatabaseState perClass = new DatabaseState.PerClass(); - final DatabaseState perExecution = new DatabaseState.PerExecution(); - - private volatile boolean templateDatabaseMigrated = false; - - void ensureTemplateDatabaseMigrated(Consumer migrator) { - if (!templateDatabaseMigrated) { - synchronized (this) { - if (!templateDatabaseMigrated) { - migrateTemplateDatabase(migrator, dataSourceForDatabase(getTemplateDatabaseName())); - templateDatabaseMigrated = true; - } - } - } +class TestDatabase { + private final DatabaseState perMethod; + private final DatabaseState perClass; + private final DatabaseState perExecution; + + TestDatabase(DatabaseEngine databaseEngine) { + this.perMethod = new DatabaseState.PerMethod(databaseEngine); + this.perClass = new DatabaseState.PerClass(databaseEngine); + this.perExecution = new DatabaseState.PerExecution(databaseEngine); } - void createDatabase(String databaseName) { - executePrivileged("CREATE DATABASE " + databaseName); - } - - void dropDatabase(String databaseName) { - executePrivileged("DROP DATABASE " + databaseName); - } - - abstract void cloneTemplateDatabaseTo(String targetDatabaseName); - - void executePrivileged(String sql) { - DataSource dataSource = getPrivilegedDataSource(); - try (Connection conn = dataSource.getConnection()) { - conn.createStatement().execute(sql); - } catch (SQLException e) { - throw new RuntimeException(e); + DatabaseState getState(DatabaseTestExtension.Mode mode) { + switch (mode) { + case DATABASE_PER_TEST_METHOD: + return perMethod; + case DATABASE_PER_TEST_CLASS: + return perClass; + case DATABASE_PER_EXECUTION: + return perExecution; + default: + throw new IllegalStateException("No database state strategy exists for " + this + " mode"); } } - - abstract DataSource dataSourceForDatabase(String databaseName); - - abstract String getTemplateDatabaseName(); - - abstract DataSource getPrivilegedDataSource(); - - abstract void migrateTemplateDatabase(Consumer migrator, DataSource templateDataSource); } diff --git a/base/src/testFixtures/java/io/github/rieske/dbtest/DatabaseTest.java b/base/src/testFixtures/java/io/github/rieske/dbtest/DatabaseTest.java index 3243cea..37d51d1 100644 --- a/base/src/testFixtures/java/io/github/rieske/dbtest/DatabaseTest.java +++ b/base/src/testFixtures/java/io/github/rieske/dbtest/DatabaseTest.java @@ -11,7 +11,7 @@ import static org.assertj.core.api.Assertions.assertThat; -abstract class DatabaseTest { +public abstract class DatabaseTest { @RegisterExtension private final DatabaseTestExtension database; diff --git a/base/src/testFixtures/java/io/github/rieske/dbtest/TestRepository.java b/base/src/testFixtures/java/io/github/rieske/dbtest/TestRepository.java new file mode 100644 index 0000000..c3a38f5 --- /dev/null +++ b/base/src/testFixtures/java/io/github/rieske/dbtest/TestRepository.java @@ -0,0 +1,50 @@ +package io.github.rieske.dbtest; + +import org.assertj.core.api.Assertions; + +import javax.sql.DataSource; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.UUID; + +public class TestRepository { + private final DataSource dataSource; + + public TestRepository(DataSource dataSource) { + this.dataSource = dataSource; + } + + public void insertRandomRecord() { + UUID id = UUID.randomUUID(); + String foo = UUID.randomUUID().toString(); + executeUpdateSql("INSERT INTO some_table(id, foo) VALUES('" + id + "', '" + foo + "')"); + } + + public int getRecordCount() { + return executeQuerySql("SELECT COUNT(*) FROM some_table", rs -> { + Assertions.assertThat(rs.next()).isTrue(); + return rs.getInt(1); + }); + } + + private void executeUpdateSql(String sql) { + try (Connection connection = dataSource.getConnection(); + Statement stmt = connection.createStatement()) { + stmt.executeUpdate(sql); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } + + private T executeQuerySql(String sql, DatabaseTest.ResultSetMapper resultSetMapper) { + try (Connection connection = dataSource.getConnection(); + Statement stmt = connection.createStatement(); + ResultSet rs = stmt.executeQuery(sql)) { + return resultSetMapper.map(rs); + } catch (SQLException e) { + throw new RuntimeException(e); + } + } +} diff --git a/gradle.properties b/gradle.properties index 0a4eaec..1dfef78 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,3 +1,3 @@ org.gradle.unsafe.configuration-cache=true group=io.github.rieske.dbtest -version=0.0.2 +version=0.0.3 diff --git a/mysql/src/main/java/io/github/rieske/dbtest/extension/MySQLTestDatabase.java b/mysql/src/main/java/io/github/rieske/dbtest/extension/MySQLTestDatabase.java index 863536e..7aee98a 100644 --- a/mysql/src/main/java/io/github/rieske/dbtest/extension/MySQLTestDatabase.java +++ b/mysql/src/main/java/io/github/rieske/dbtest/extension/MySQLTestDatabase.java @@ -9,7 +9,7 @@ import java.util.Map; import java.util.function.Consumer; -class MySQLTestDatabase extends TestDatabase { +class MySQLTestDatabase extends DatabaseEngine { private static final String DB_DUMP_FILENAME = "db_dump.sql"; private final MySQLContainer container; diff --git a/mysql/src/main/java/io/github/rieske/dbtest/extension/MySQLTestDatabaseManager.java b/mysql/src/main/java/io/github/rieske/dbtest/extension/MySQLTestDatabaseManager.java index 783d939..79fb76a 100644 --- a/mysql/src/main/java/io/github/rieske/dbtest/extension/MySQLTestDatabaseManager.java +++ b/mysql/src/main/java/io/github/rieske/dbtest/extension/MySQLTestDatabaseManager.java @@ -4,9 +4,9 @@ import java.util.concurrent.ConcurrentHashMap; class MySQLTestDatabaseManager { - private static final Map DATABASES = new ConcurrentHashMap<>(); + private static final Map DATABASES = new ConcurrentHashMap<>(); - static MySQLTestDatabase getDatabase(String version) { - return DATABASES.computeIfAbsent(version, MySQLTestDatabase::new); + static TestDatabase getDatabase(String version) { + return DATABASES.computeIfAbsent(version, v -> new TestDatabase(new MySQLTestDatabase(v))); } } diff --git a/postgresql/src/main/java/io/github/rieske/dbtest/extension/PostgreSQLTestDatabase.java b/postgresql/src/main/java/io/github/rieske/dbtest/extension/PostgreSQLTestDatabase.java index 3b72cce..9f9762d 100644 --- a/postgresql/src/main/java/io/github/rieske/dbtest/extension/PostgreSQLTestDatabase.java +++ b/postgresql/src/main/java/io/github/rieske/dbtest/extension/PostgreSQLTestDatabase.java @@ -7,7 +7,7 @@ import java.util.Map; import java.util.function.Consumer; -class PostgreSQLTestDatabase extends TestDatabase { +class PostgreSQLTestDatabase extends DatabaseEngine { private final PostgreSQLContainer container; private final String jdbcPrefix; diff --git a/postgresql/src/main/java/io/github/rieske/dbtest/extension/PostgreSQLTestDatabaseManager.java b/postgresql/src/main/java/io/github/rieske/dbtest/extension/PostgreSQLTestDatabaseManager.java index cee3248..70f7358 100644 --- a/postgresql/src/main/java/io/github/rieske/dbtest/extension/PostgreSQLTestDatabaseManager.java +++ b/postgresql/src/main/java/io/github/rieske/dbtest/extension/PostgreSQLTestDatabaseManager.java @@ -4,9 +4,9 @@ import java.util.concurrent.ConcurrentHashMap; class PostgreSQLTestDatabaseManager { - private static final Map DATABASES = new ConcurrentHashMap<>(); + private static final Map DATABASES = new ConcurrentHashMap<>(); - static PostgreSQLTestDatabase getDatabase(String version) { - return DATABASES.computeIfAbsent(version, PostgreSQLTestDatabase::new); + static TestDatabase getDatabase(String version) { + return DATABASES.computeIfAbsent(version, v -> new TestDatabase(new PostgreSQLTestDatabase(v))); } } diff --git a/postgresql/src/test/java/io/github/rieske/dbtest/PostgreSQLTest.java b/postgresql/src/test/java/io/github/rieske/dbtest/PostgreSQLTest.java index 981043f..6b0323f 100644 --- a/postgresql/src/test/java/io/github/rieske/dbtest/PostgreSQLTest.java +++ b/postgresql/src/test/java/io/github/rieske/dbtest/PostgreSQLTest.java @@ -4,8 +4,8 @@ import io.github.rieske.dbtest.extension.FlywayPostgreSQLFastTestExtension; import io.github.rieske.dbtest.extension.FlywayPostgreSQLSlowTestExtension; -interface PostgreSQLTest { - default String postgresVersion() { +public interface PostgreSQLTest { + static String postgresVersion() { return Environment.getEnvOrDefault("POSTGRES_VERSION", "15.2-alpine"); } diff --git a/postgresql/src/test/java/io/github/rieske/dbtest/lifecycle/FieldExtensionLifecycleTests.java b/postgresql/src/test/java/io/github/rieske/dbtest/lifecycle/FieldExtensionLifecycleTests.java new file mode 100644 index 0000000..32c62d9 --- /dev/null +++ b/postgresql/src/test/java/io/github/rieske/dbtest/lifecycle/FieldExtensionLifecycleTests.java @@ -0,0 +1,73 @@ +package io.github.rieske.dbtest.lifecycle; + +import io.github.rieske.dbtest.PostgreSQLTest; +import io.github.rieske.dbtest.TestRepository; +import io.github.rieske.dbtest.extension.DatabaseTestExtension; +import io.github.rieske.dbtest.extension.FlywayPostgreSQLFastTestExtension; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import static org.assertj.core.api.Assertions.assertThat; + +class FieldExtensionLifecycleTests { + + @Nested + class PerMethodTests { + @RegisterExtension + private final DatabaseTestExtension database = + new FlywayPostgreSQLFastTestExtension(PostgreSQLTest.postgresVersion(), DatabaseTestExtension.Mode.DATABASE_PER_TEST_METHOD); + + private final TestRepository repository = new TestRepository(database.getDataSource()); + + @Test + void interactWithDatabase() { + assertThat(repository.getRecordCount()).isZero(); + repository.insertRandomRecord(); + assertThat(repository.getRecordCount()).isOne(); + } + } + + @Nested + class PerClassTests { + @RegisterExtension + private final DatabaseTestExtension database = + new FlywayPostgreSQLFastTestExtension(PostgreSQLTest.postgresVersion(), DatabaseTestExtension.Mode.DATABASE_PER_TEST_CLASS); + + private final EarlyDatabaseExtensionUser exceptionSupplier = new EarlyDatabaseExtensionUser(database); + + @Test + void throwsWhenPerClassExtensionIsRegisteredAsInstanceFieldAndAccessedDuringTestInstantiation() { + assertThat(exceptionSupplier.expectedException).hasMessage("Per-class database extension must be registered as a static test class field in order to use the datasource during test instance construction."); + } + + class EarlyDatabaseExtensionUser { + private final IllegalStateException expectedException; + + EarlyDatabaseExtensionUser(DatabaseTestExtension extension) { + try { + extension.getDataSource(); + throw new RuntimeException("Expected an IllegalStateException to be thrown"); + } catch (IllegalStateException e) { + this.expectedException = e; + } + } + } + } + + @Nested + class PerExecutionTests { + @RegisterExtension + private final DatabaseTestExtension database = + new FlywayPostgreSQLFastTestExtension(PostgreSQLTest.postgresVersion(), DatabaseTestExtension.Mode.DATABASE_PER_EXECUTION); + + private final TestRepository repository = new TestRepository(database.getDataSource()); + + @Test + void interactWithDatabase() { + int recordCount = repository.getRecordCount(); + repository.insertRandomRecord(); + assertThat(repository.getRecordCount()).isGreaterThan(recordCount); + } + } +} diff --git a/postgresql/src/test/java/io/github/rieske/dbtest/lifecycle/PerClassStaticExtensionTest.java b/postgresql/src/test/java/io/github/rieske/dbtest/lifecycle/PerClassStaticExtensionTest.java new file mode 100644 index 0000000..7e21060 --- /dev/null +++ b/postgresql/src/test/java/io/github/rieske/dbtest/lifecycle/PerClassStaticExtensionTest.java @@ -0,0 +1,25 @@ +package io.github.rieske.dbtest.lifecycle; + +import io.github.rieske.dbtest.PostgreSQLTest; +import io.github.rieske.dbtest.TestRepository; +import io.github.rieske.dbtest.extension.DatabaseTestExtension; +import io.github.rieske.dbtest.extension.FlywayPostgreSQLFastTestExtension; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import static org.assertj.core.api.Assertions.assertThat; + +class PerClassStaticExtensionTest { + @RegisterExtension + private static final DatabaseTestExtension database = + new FlywayPostgreSQLFastTestExtension(PostgreSQLTest.postgresVersion(), DatabaseTestExtension.Mode.DATABASE_PER_TEST_CLASS); + + private final TestRepository repository = new TestRepository(database.getDataSource()); + + @Test + void interactWithDatabase() { + assertThat(repository.getRecordCount()).isZero(); + repository.insertRandomRecord(); + assertThat(repository.getRecordCount()).isOne(); + } +}