// MultiTenantConnectionProviderImpl.java
package com.tradecloud.repository.multitenant;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.log4j.Logger;
import org.flywaydb.core.Flyway;
import org.flywaydb.core.api.output.*;
import org.hibernate.engine.jdbc.connections.spi.AbstractDataSourceBasedMultiTenantConnectionProviderImpl;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextClosedEvent;
import org.springframework.stereotype.Component;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Hibernate multi-tenant connection provider that keeps one HikariCP pool per tenant database.
 *
 * <p>Pools are created lazily the first time a tenant is requested, optionally migrated with
 * Flyway, cached in {@link #registry}, and closed when the Spring context is closed. A tenant
 * whose database cannot be connected to is served from the shared default database pool.
 *
 * <p>Thread-safety: pools are created through {@link ConcurrentHashMap#computeIfAbsent}, so at
 * most one pool is built (and migrated) per tenant even when several threads request the same
 * tenant at once. A pool only becomes visible to other threads after its migration succeeded.
 */
@Component("multiTenantConnectionProvider")
public class MultiTenantConnectionProviderImpl extends AbstractDataSourceBasedMultiTenantConnectionProviderImpl
        implements ApplicationListener<ContextClosedEvent> {

    private static final Logger log = Logger.getLogger(MultiTenantConnectionProviderImpl.class);
    private static final long serialVersionUID = 6241633589847209550L;

    /** Database name (and registry key) of the shared fallback database. */
    private static final String DEFAULT_TENANT = "tradecloud_default";
    private static final String POSTGRES_DRIVER = "org.postgresql.Driver";

    private static final int DEFAULT_POOL_MAX_SIZE = 5;
    private static final int DEFAULT_POOL_MIN_IDLE = 1;
    private static final int TENANT_POOL_MAX_SIZE = 50;
    private static final int TENANT_POOL_MIN_IDLE = 5;
    /** Connection max lifetime used when {@code max_connection_age} is unset or invalid (30 minutes). */
    private static final int DEFAULT_MAX_LIFETIME_MS = 1_800_000;

    // NOTE(review): both messages below look like c3p0 error texts; confirm HikariCP can ever
    // produce them, otherwise the pool-reset path in getConnection never triggers.
    private static final String CHECKOUT_TIMEOUT_MESSAGE =
            "An attempt by a client to checkout a Connection has timed out.";
    private static final String RESOURCE_POOL_EXCEPTION_MARKER = "ResourcePoolException";

    /** Value of the {@code active_clients} environment variable; Flyway only runs when it is set. */
    private final String serverActiveClients = System.getenv("active_clients");
    /** Raw value of the {@code max_connection_age} environment variable (milliseconds), may be null. */
    private final String maxConnectionAge = System.getenv("max_connection_age");

    @Autowired
    private AwsSecretManagerConfig awsSecretManagerConfig;

    /** Live pools keyed by tenant identifier (which is also the database name). */
    private final ConcurrentHashMap<String, HikariDataSource> registry = new ConcurrentHashMap<>();

    /**
     * Returns the pool for the shared default database, creating and migrating it on first use.
     *
     * @return the default data source, never {@code null}
     */
    public DataSource getDefaultDataSource() {
        return registry.computeIfAbsent(DEFAULT_TENANT, tenant -> createDataSource(
                tenant, DEFAULT_POOL_MAX_SIZE, DEFAULT_POOL_MIN_IDLE, DEFAULT_MAX_LIFETIME_MS, true));
    }

    @Override
    protected DataSource selectAnyDataSource() {
        return getDefaultDataSource();
    }

    /**
     * Returns the pool for the given tenant, creating and migrating it on first use. Falls back to
     * the default data source when no connection to the tenant database can be opened.
     *
     * @param tenantIdentifier tenant id, used verbatim as the database name
     * @return the tenant's data source, or the default data source as a fallback
     */
    @Override
    protected DataSource selectDataSource(String tenantIdentifier) {
        return resolveTenantDataSource(tenantIdentifier, true);
    }

    /**
     * Looks up or atomically creates the tenant pool.
     *
     * <p>NOTE: there is currently no upper bound on the number of tenant pools.
     *
     * @param runMigration whether a newly created pool should be migrated with Flyway
     */
    private DataSource resolveTenantDataSource(String tenantIdentifier, boolean runMigration) {
        HikariDataSource existing = registry.get(tenantIdentifier);
        if (existing != null) {
            return existing;
        }
        // The existence probe and the default-pool fallback must stay outside computeIfAbsent:
        // getDefaultDataSource() updates the same map, which is not allowed from a mapping function.
        if (!checkDBExists(tenantIdentifier)) {
            return getDefaultDataSource();
        }
        return registry.computeIfAbsent(tenantIdentifier, tenant -> createDataSource(
                tenant, TENANT_POOL_MAX_SIZE, TENANT_POOL_MIN_IDLE, resolveMaxLifetimeMillis(), runMigration));
    }

    /**
     * Parses {@code max_connection_age}; falls back to {@link #DEFAULT_MAX_LIFETIME_MS} when the
     * variable is unset, blank, or not a valid integer.
     */
    private int resolveMaxLifetimeMillis() {
        if (maxConnectionAge == null || maxConnectionAge.trim().isEmpty()) {
            return DEFAULT_MAX_LIFETIME_MS;
        }
        try {
            int configured = Integer.parseInt(maxConnectionAge.trim());
            log.info(String.format("max age configured: %s", maxConnectionAge));
            return configured;
        } catch (NumberFormatException e) {
            log.warn("Invalid max_connection_age '" + maxConnectionAge + "', using default "
                    + DEFAULT_MAX_LIFETIME_MS, e);
            return DEFAULT_MAX_LIFETIME_MS;
        }
    }

    /** Builds the PostgreSQL JDBC URL for a database named after the tenant. */
    private String jdbcUrlFor(String tenantIdentifier) {
        return "jdbc:postgresql://"
                + StringUtils.defaultString(awsSecretManagerConfig.getDatabaseHost(), "db")
                + ":5432/"
                + tenantIdentifier;
    }

    // NOTE(review): the fallback credentials below are hard-coded; confirm they are only ever
    // relevant for local development and consider removing them from source.
    private String databaseUsername() {
        return StringUtils.defaultString(awsSecretManagerConfig.getDatabaseUsername(), "ds");
    }

    private String databasePassword() {
        return StringUtils.defaultString(awsSecretManagerConfig.getDatabasePassword(), "netscape");
    }

    /**
     * Builds a new pool for the tenant and, if requested, migrates its database.
     *
     * <p>The pool is NOT registered here; callers publish it via {@code computeIfAbsent} so that a
     * pool whose migration failed is never handed out. On migration failure the pool is closed and
     * the failure is rethrown.
     *
     * @param tenantIdentifier tenant id / database name
     * @param max maximum pool size
     * @param min minimum number of idle connections
     * @param maxLifetimeMs maximum connection lifetime in milliseconds
     * @param runMigration whether to run Flyway against the new pool
     * @return the ready-to-use pool
     */
    private HikariDataSource createDataSource(String tenantIdentifier, int max, int min, int maxLifetimeMs,
            boolean runMigration) {
        HikariConfig config = new HikariConfig();
        config.setJdbcUrl(jdbcUrlFor(tenantIdentifier));
        config.setUsername(databaseUsername());
        config.setPassword(databasePassword());
        config.setDriverClassName(POSTGRES_DRIVER);
        config.setMaximumPoolSize(max);
        config.setIdleTimeout(300_000); // 5 minutes
        config.setConnectionTimeout(30_000); // 30 seconds
        config.setPoolName("tradecloud_" + tenantIdentifier);
        config.setMinimumIdle(min);
        config.setMaxLifetime(maxLifetimeMs);
        config.setConnectionTestQuery("SELECT 1;");
        config.setLeakDetectionThreshold(60_000);
        config.setValidationTimeout(5000);
        config.setKeepaliveTime(600000);
        config.setAutoCommit(true);

        HikariDataSource pool = new HikariDataSource(config);
        if (runMigration) {
            try {
                migration(pool);
            } catch (RuntimeException | Error e) {
                // Do not leak connections of a pool that will never be registered.
                pool.close();
                throw e;
            }
        }
        return pool;
    }

    /**
     * Runs Flyway baseline, repair and migrate against the given pool.
     *
     * <p>Skipped when {@code active_clients} is not set, or when the {@code skipMigration_env}
     * environment variable (or, if that is blank, the {@code skipMigration} system property)
     * equals {@code "true"}. The baseline version defaults to {@code "0"} and can be overridden
     * per pool with the environment variable {@code baselineVersion_<poolName>}.
     */
    private void migration(HikariDataSource cpds) {
        String skipMigrationEnv = System.getenv("skipMigration_env");
        log.debug("skipMigrationEnv2: " + skipMigrationEnv);
        String skipMigration = Optional.ofNullable(skipMigrationEnv)
                .filter(env -> !env.isBlank())
                .orElseGet(() -> System.getProperty("skipMigration"));
        log.debug("skipMigration: " + skipMigration);

        if (serverActiveClients == null) {
            log.debug("####################### active_clients not set, skipping flyway migration ###################################");
            return;
        }
        if ("true".equals(skipMigration)) {
            log.debug("####################### skipMigration=true, skipping flyway migration ###################################");
            return;
        }

        String baselineVersion = "0";
        String baselineConfig = System.getenv("baselineVersion_" + cpds.getPoolName());
        if (baselineConfig != null && !baselineConfig.trim().isEmpty()) {
            baselineVersion = baselineConfig;
        }
        Flyway flyway = Flyway.configure().outOfOrder(true).placeholderReplacement(false)
                .dataSource(cpds).baselineVersion(baselineVersion)
                .load();

        log.debug("MIGRATING " + cpds.getPoolName());
        BaselineResult baseline = flyway.baseline();
        log.debug("#########################baseline#################################");
        log.debug("successfullyBaselined: " + baseline.successfullyBaselined);
        log.debug("database: " + baseline.database);
        for (String warning : baseline.warnings) {
            log.debug("warning: " + warning);
        }

        log.debug("#########################repair#################################");
        RepairResult repair = flyway.repair();
        if (repair != null) {
            logRepairsDone(repair);
        }

        log.debug("#########################migration#################################");
        MigrateResult migrate = flyway.migrate();
        log.debug("database: " + migrate.database);
        for (String warning : migrate.warnings) {
            log.debug("warning: " + warning);
        }
        for (MigrateOutput output : migrate.migrations) {
            log.debug("category: " + output.category);
            log.debug("type: " + output.type);
            log.debug("description: " + output.description);
            log.debug("executionTime: " + output.executionTime);
            log.debug("version: " + output.version);
        }
        log.debug("##########################################################");
    }

    /**
     * Logs, at debug level, every entry of a Flyway repair result (aligned, deleted and removed
     * migrations). Empty sections are omitted.
     *
     * @param repair the repair result to log; must not be {@code null}
     */
    public static void logRepairsDone(RepairResult repair) {
        log.debug("database:" + repair.database);
        logRepairOutputs("migrationsAligned:", repair.migrationsAligned);
        logRepairOutputs("migrationsDeleted:", repair.migrationsDeleted);
        logRepairOutputs("migrationsRemoved:", repair.migrationsRemoved);
    }

    /** Logs one section of a repair result under the given heading; does nothing when it is empty. */
    private static void logRepairOutputs(String heading, java.util.Collection<RepairOutput> outputs) {
        if (outputs.isEmpty()) {
            return;
        }
        log.debug(heading);
        for (RepairOutput repairOutput : outputs) {
            log.debug("version:" + repairOutput.version);
            log.debug("description:" + repairOutput.description);
            log.debug("filepath:" + repairOutput.filepath);
        }
    }

    /**
     * Probes the tenant database by opening (and immediately closing) a direct JDBC connection.
     *
     * <p>NOTE(review): ANY failure to connect (wrong credentials, network problem, missing driver)
     * is reported as "does not exist", which makes the caller fall back to the default database.
     * Confirm that routing a tenant to the default database in those cases is intended.
     *
     * @return {@code true} if a connection could be opened, {@code false} otherwise
     */
    private boolean checkDBExists(String tenantIdentifier) {
        Connection probe;
        try {
            Class.forName(POSTGRES_DRIVER);
            probe = DriverManager.getConnection(jdbcUrlFor(tenantIdentifier), databaseUsername(), databasePassword());
        } catch (Exception e) {
            log.warn("Could not connect to database for tenant: " + tenantIdentifier, e);
            return false;
        }
        try {
            probe.close();
        } catch (SQLException e) {
            // The database was reachable; a failed close of the probe does not change the answer.
            log.warn("Failed to close probe connection for tenant: " + tenantIdentifier, e);
        }
        return true;
    }

    /**
     * Obtains a connection for the tenant. If that fails with what looks like a pool-level
     * checkout failure, the tenant's pool is discarded and re-created (without re-running the
     * migration) so that later calls get a fresh pool; the original exception is always rethrown.
     *
     * @param tenantIdentifier tenant id
     * @return an open connection from the tenant's pool
     * @throws SQLException if no connection could be obtained
     */
    @Override
    public Connection getConnection(String tenantIdentifier) throws SQLException {
        try {
            return super.getConnection(tenantIdentifier);
        } catch (SQLException e) {
            log.error("Error getting connection for tenant: " + tenantIdentifier, e);
            if (isPoolFailure(e)) {
                try {
                    resetPool(tenantIdentifier);
                } catch (RuntimeException resetFailure) {
                    // Keep the original failure as the one the caller sees.
                    log.error("Failed to re-create pool for tenant: " + tenantIdentifier, resetFailure);
                    e.addSuppressed(resetFailure);
                }
            }
            throw e;
        }
    }

    /** Returns whether the exception message matches one of the known pool-failure messages. */
    private static boolean isPoolFailure(SQLException e) {
        String message = e.getLocalizedMessage();
        return message != null
                && (CHECKOUT_TIMEOUT_MESSAGE.equals(message) || message.contains(RESOURCE_POOL_EXCEPTION_MARKER));
    }

    /** Closes and unregisters the tenant's pool, then re-creates it without running the migration. */
    private void resetPool(String tenantIdentifier) {
        // Unregister first so no other thread is handed a pool that is being closed.
        HikariDataSource stale = registry.remove(tenantIdentifier);
        if (stale != null) {
            stale.close();
        }
        resolveTenantDataSource(tenantIdentifier, false);
    }

    /**
     * Closes every registered pool when the Spring application context is closed. A failure to
     * close one pool is logged and does not prevent the remaining pools from being closed.
     */
    @Override
    public void onApplicationEvent(ContextClosedEvent contextClosedEvent) {
        registry.forEach((tenant, pool) -> {
            try {
                pool.close();
            } catch (Exception e) {
                log.error("Failed to close pool for tenant: " + tenant, e);
            }
        });
        registry.clear();
    }
}