Skip to content

Commit

Permalink
Add tests and support for database usage during test instance constru…
Browse files Browse the repository at this point in the history
…ction (#50)

* Add tests and support for database usage during test instance construction

* Extract DatabaseEngine
  • Loading branch information
rieske authored Feb 19, 2023
1 parent 1b3df93 commit cbe4641
Show file tree
Hide file tree
Showing 14 changed files with 322 additions and 102 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package io.github.rieske.dbtest.extension;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.function.Consumer;

abstract class DatabaseEngine {
private volatile boolean templateDatabaseMigrated = false;

void ensureTemplateDatabaseMigrated(Consumer<DataSource> migrator) {
if (!templateDatabaseMigrated) {
synchronized (this) {
if (!templateDatabaseMigrated) {
migrateTemplateDatabase(migrator, dataSourceForDatabase(getTemplateDatabaseName()));
templateDatabaseMigrated = true;
}
}
}
}

void createDatabase(String databaseName) {
executePrivileged("CREATE DATABASE " + databaseName);
}

void dropDatabase(String databaseName) {
executePrivileged("DROP DATABASE " + databaseName);
}

abstract void cloneTemplateDatabaseTo(String targetDatabaseName);

void executePrivileged(String sql) {
DataSource dataSource = getPrivilegedDataSource();
try (Connection conn = dataSource.getConnection()) {
conn.createStatement().execute(sql);
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

abstract DataSource dataSourceForDatabase(String databaseName);

abstract String getTemplateDatabaseName();

abstract DataSource getPrivilegedDataSource();

abstract void migrateTemplateDatabase(Consumer<DataSource> migrator, DataSource templateDataSource);
}
Original file line number Diff line number Diff line change
@@ -1,61 +1,112 @@
package io.github.rieske.dbtest.extension;

import javax.sql.DataSource;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.BiConsumer;

interface DatabaseState {
String ensureDatabaseCreated(Class<?> testClass, Consumer<String> databaseCreator);
abstract class DatabaseState {
protected final DatabaseEngine database;

static String newDatabaseName() {
DatabaseState(DatabaseEngine database) {
this.database = database;
}

abstract String ensureDatabaseCreated(String databaseName, Class<?> testClass, BiConsumer<DatabaseEngine, String> databaseCreator);
abstract void afterTestMethod(String databaseName);

private static String newDatabaseName() {
return "testdb_" + UUID.randomUUID().toString().replace('-', '_');
}

class PerMethod implements DatabaseState {
DataSource dataSourceForDatabase(String databaseName) {
return database.dataSourceForDatabase(databaseName);
}

static class PerMethod extends DatabaseState {
PerMethod(DatabaseEngine database) {
super(database);
}

@Override
public String ensureDatabaseCreated(Class<?> testClass, Consumer<String> databaseCreator) {
String databaseName = newDatabaseName();
databaseCreator.accept(databaseName);
String ensureDatabaseCreated(String databaseName, Class<?> testClass, BiConsumer<DatabaseEngine, String> databaseCreator) {
if (databaseName != null) {
return databaseName;
}
databaseName = newDatabaseName();
databaseCreator.accept(database, databaseName);
return databaseName;
}

@Override
void afterTestMethod(String databaseName) {
if (databaseName != null) {
database.dropDatabase(databaseName);
}
}
}

class PerClass implements DatabaseState {
static class PerClass extends DatabaseState {
private final Map<Class<?>, String> perClassDatabases = new ConcurrentHashMap<>();

PerClass(DatabaseEngine database) {
super(database);
}

@Override
public String ensureDatabaseCreated(Class<?> testClass, Consumer<String> databaseCreator) {
String ensureDatabaseCreated(String databaseName, Class<?> testClass, BiConsumer<DatabaseEngine, String> databaseCreator) {
if (testClass == null) {
throw new IllegalStateException("Per-class database extension must be registered as a static test class field in order to use the datasource during test instance construction.");
}
if (databaseName != null && databaseName.equals(perClassDatabases.get(testClass))) {
return databaseName;
}
String perClassDatabaseName = perClassDatabases.get(testClass);
if (perClassDatabaseName == null) {
synchronized (perClassDatabases) {
perClassDatabaseName = perClassDatabases.get(testClass);
if (perClassDatabaseName == null) {
perClassDatabaseName = newDatabaseName();
databaseCreator.accept(perClassDatabaseName);
databaseCreator.accept(database, perClassDatabaseName);
perClassDatabases.put(testClass, perClassDatabaseName);
}
}
}
return perClassDatabaseName;
}

@Override
void afterTestMethod(String databaseName) {
}
}

class PerExecution implements DatabaseState {
static class PerExecution extends DatabaseState {
private final String perExecutionDatabaseName = newDatabaseName();
private volatile boolean perExecutionDatabaseCreated = false;

PerExecution(DatabaseEngine database) {
super(database);
}

@Override
public String ensureDatabaseCreated(Class<?> testClass, Consumer<String> databaseCreator) {
String ensureDatabaseCreated(String databaseName, Class<?> testClass, BiConsumer<DatabaseEngine, String> databaseCreator) {
if (perExecutionDatabaseName.equals(databaseName)) {
return databaseName;
}
if (!perExecutionDatabaseCreated) {
synchronized (perExecutionDatabaseName) {
if (!perExecutionDatabaseCreated) {
databaseCreator.accept(perExecutionDatabaseName);
databaseCreator.accept(database, perExecutionDatabaseName);
perExecutionDatabaseCreated = true;
}
}
}
return perExecutionDatabaseName;
}

@Override
void afterTestMethod(String databaseName) {
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.Extension;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.InvocationInterceptor;
import org.junit.jupiter.api.extension.ReflectiveInvocationContext;

import javax.sql.DataSource;
import java.lang.reflect.Constructor;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

/**
* Base class for concrete database test extension implementations.
* Encapsulates the core extension behavior that does not rely on backing database specifics.
*/
public abstract class DatabaseTestExtension implements Extension, BeforeEachCallback, AfterEachCallback {
public abstract class DatabaseTestExtension implements Extension, BeforeEachCallback, AfterEachCallback, InvocationInterceptor {

/**
* Extension execution mode. Defines the database state guarantees for test executions.
Expand All @@ -29,44 +33,38 @@ public enum Mode {
/**
* Create a single database per JVM process. Any state written by tests will be visible to other tests.
*/
DATABASE_PER_EXECUTION;

private String ensureDatabaseCreated(TestDatabase database, Consumer<String> databaseCreator, Class<?> testClass) {
switch (this) {
case DATABASE_PER_TEST_METHOD:
return database.perMethod.ensureDatabaseCreated(testClass, databaseCreator);
case DATABASE_PER_TEST_CLASS:
return database.perClass.ensureDatabaseCreated(testClass, databaseCreator);
case DATABASE_PER_EXECUTION:
return database.perExecution.ensureDatabaseCreated(testClass, databaseCreator);
default:
throw new IllegalStateException("No strategy exists for " + this + " mode");
}
}
DATABASE_PER_EXECUTION
}

private final TestDatabase database;
private final Consumer<String> databaseCreator;
private final Mode mode;
private final DatabaseState databaseState;
private final BiConsumer<DatabaseEngine, String> databaseCreator;

private Class<?> testClass;
private String databaseName;

DatabaseTestExtension(TestDatabase database, Mode mode, boolean migrateOnce) {
this.database = database;
this.databaseCreator = makeDatabaseCreator(database, migrateOnce);
this.mode = mode;
this.databaseState = database.getState(mode);
this.databaseCreator = makeDatabaseCreator(migrateOnce, this::migrateDatabase);
}

@Override
public <T> T interceptTestClassConstructor(
Invocation<T> invocation,
ReflectiveInvocationContext<Constructor<T>> invocationContext,
ExtensionContext extensionContext
) throws Throwable {
this.testClass = extensionContext.getRequiredTestClass();
return invocation.proceed();
}

@Override
public void beforeEach(ExtensionContext context) {
this.databaseName = mode.ensureDatabaseCreated(database, databaseCreator, context.getRequiredTestClass());
this.testClass = context.getRequiredTestClass();
}

@Override
public void afterEach(ExtensionContext context) {
if (mode == Mode.DATABASE_PER_TEST_METHOD) {
database.dropDatabase(databaseName);
}
databaseState.afterTestMethod(databaseName);
}

/**
Expand All @@ -76,7 +74,8 @@ public void afterEach(ExtensionContext context) {
* @return dataSource for a migrated database
*/
public DataSource getDataSource() {
return database.dataSourceForDatabase(databaseName);
this.databaseName = databaseState.ensureDatabaseCreated(databaseName, testClass, databaseCreator);
return databaseState.dataSourceForDatabase(databaseName);
}

/**
Expand All @@ -87,16 +86,16 @@ public DataSource getDataSource() {
*/
abstract protected void migrateDatabase(DataSource dataSource);

private Consumer<String> makeDatabaseCreator(TestDatabase database, boolean migrateOnce) {
private static BiConsumer<DatabaseEngine, String> makeDatabaseCreator(boolean migrateOnce, Consumer<DataSource> migrator) {
if (migrateOnce) {
return databaseName -> {
database.ensureTemplateDatabaseMigrated(this::migrateDatabase);
return (database, databaseName) -> {
database.ensureTemplateDatabaseMigrated(migrator);
database.cloneTemplateDatabaseTo(databaseName);
};
} else {
return databaseName -> {
return (database, databaseName) -> {
database.createDatabase(databaseName);
migrateDatabase(database.dataSourceForDatabase(databaseName));
migrator.accept(database.dataSourceForDatabase(databaseName));
};
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,52 +1,26 @@
package io.github.rieske.dbtest.extension;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.function.Consumer;

abstract class TestDatabase {
final DatabaseState perMethod = new DatabaseState.PerMethod();
final DatabaseState perClass = new DatabaseState.PerClass();
final DatabaseState perExecution = new DatabaseState.PerExecution();

private volatile boolean templateDatabaseMigrated = false;

void ensureTemplateDatabaseMigrated(Consumer<DataSource> migrator) {
if (!templateDatabaseMigrated) {
synchronized (this) {
if (!templateDatabaseMigrated) {
migrateTemplateDatabase(migrator, dataSourceForDatabase(getTemplateDatabaseName()));
templateDatabaseMigrated = true;
}
}
}
class TestDatabase {
private final DatabaseState perMethod;
private final DatabaseState perClass;
private final DatabaseState perExecution;

TestDatabase(DatabaseEngine databaseEngine) {
this.perMethod = new DatabaseState.PerMethod(databaseEngine);
this.perClass = new DatabaseState.PerClass(databaseEngine);
this.perExecution = new DatabaseState.PerExecution(databaseEngine);
}

void createDatabase(String databaseName) {
executePrivileged("CREATE DATABASE " + databaseName);
}

void dropDatabase(String databaseName) {
executePrivileged("DROP DATABASE " + databaseName);
}

abstract void cloneTemplateDatabaseTo(String targetDatabaseName);

void executePrivileged(String sql) {
DataSource dataSource = getPrivilegedDataSource();
try (Connection conn = dataSource.getConnection()) {
conn.createStatement().execute(sql);
} catch (SQLException e) {
throw new RuntimeException(e);
DatabaseState getState(DatabaseTestExtension.Mode mode) {
switch (mode) {
case DATABASE_PER_TEST_METHOD:
return perMethod;
case DATABASE_PER_TEST_CLASS:
return perClass;
case DATABASE_PER_EXECUTION:
return perExecution;
default:
throw new IllegalStateException("No database state strategy exists for " + this + " mode");
}
}

abstract DataSource dataSourceForDatabase(String databaseName);

abstract String getTemplateDatabaseName();

abstract DataSource getPrivilegedDataSource();

abstract void migrateTemplateDatabase(Consumer<DataSource> migrator, DataSource templateDataSource);
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import static org.assertj.core.api.Assertions.assertThat;

abstract class DatabaseTest {
public abstract class DatabaseTest {
@RegisterExtension
private final DatabaseTestExtension database;

Expand Down
Loading

0 comments on commit cbe4641

Please sign in to comment.