diff --git a/libraft-agent/src/main/java/io/libraft/agent/RaftAgent.java b/libraft-agent/src/main/java/io/libraft/agent/RaftAgent.java index 9c815a1..f1e083d 100644 --- a/libraft-agent/src/main/java/io/libraft/agent/RaftAgent.java +++ b/libraft-agent/src/main/java/io/libraft/agent/RaftAgent.java @@ -234,6 +234,7 @@ private RaftAgent(RaftConfiguration configuration, RaftListener raftListener) { timer, mapper, getSelfAsMember(raftClusterConfiguration.getSelf(), cluster), + configuration.getAllAddresses(), cluster, configuration.getConnectTimeout(), configuration.getMinReconnectInterval(), diff --git a/libraft-agent/src/main/java/io/libraft/agent/RaftAgentConstants.java b/libraft-agent/src/main/java/io/libraft/agent/RaftAgentConstants.java index 83fb2f4..7d695fd 100644 --- a/libraft-agent/src/main/java/io/libraft/agent/RaftAgentConstants.java +++ b/libraft-agent/src/main/java/io/libraft/agent/RaftAgentConstants.java @@ -63,4 +63,10 @@ private RaftAgentConstants() { } // to protect instantiation * a successful connection to a Raft server. */ public static final int CONNECT_TIMEOUT = 5000; + + /** + * Indicates whether the local server's listener should listen + * on all addresses or only on the specified address + */ + public static final boolean ALL_ADDRESSES = false; } diff --git a/libraft-agent/src/main/java/io/libraft/agent/configuration/RaftConfiguration.java b/libraft-agent/src/main/java/io/libraft/agent/configuration/RaftConfiguration.java index 9ab2982..20e42fc 100644 --- a/libraft-agent/src/main/java/io/libraft/agent/configuration/RaftConfiguration.java +++ b/libraft-agent/src/main/java/io/libraft/agent/configuration/RaftConfiguration.java @@ -73,6 +73,7 @@ public final class RaftConfiguration { private static final String SNAPSHOTS = "snapshots"; private static final String DATABASE = "database"; private static final String CLUSTER = "cluster"; + private static final String ALL_ADDRESSES = "allAddresses"; @Min(1) @Max(RaftConfigurationConstants.SIXTY_SECONDS) @@ -115,6 +116,10 @@ public final class RaftConfiguration { @NotNull @JsonProperty(ADDITIONAL_RECONNECT_INTERVAL_RANGE) private int additionalReconnectIntervalRange = RaftAgentConstants.ADDITIONAL_RECONNECT_INTERVAL_RANGE; + + @NotNull + @JsonProperty(ALL_ADDRESSES) + private boolean allAddresses = RaftAgentConstants.ALL_ADDRESSES; @JsonIgnore private final TimeUnit timeUnit = RaftConfigurationConstants.DEFAULT_TIME_UNIT; @@ -330,6 +335,16 @@ public void setMinReconnectInterval(int minReconnectInterval) { public int getAdditionalReconnectIntervalRange() { return additionalReconnectIntervalRange; } + + /** + * Get a flag indicating whether the local server should listen on all addresses + * or on the specified one only + * + * @return the all-addresses flag + */ + public boolean getAllAddresses() { + return allAddresses; + } /** * Set the maximum additional amount of time added to @@ -343,6 +358,16 @@ public int getAdditionalReconnectIntervalRange() { public void setAdditionalReconnectIntervalRange(int additionalReconnectIntervalRange) { this.additionalReconnectIntervalRange = additionalReconnectIntervalRange; } + + /** + * Set the flag indicating whether the local server listens on all addresses + * or not + * + * @param aa the all-addresses value + */ + public void setAllAddresses(boolean aa) { + this.allAddresses = aa; + } /** * Get the Raft database configuration. @@ -406,7 +431,8 @@ public boolean equals(@Nullable Object o) { && additionalReconnectIntervalRange == other.additionalReconnectIntervalRange && raftDatabaseConfiguration.equals(other.raftDatabaseConfiguration) && raftSnapshotsConfiguration.equals(other.raftSnapshotsConfiguration) - && raftClusterConfiguration.equals(other.raftClusterConfiguration); + && raftClusterConfiguration.equals(other.raftClusterConfiguration) + && allAddresses == other.allAddresses; } @Override @@ -422,7 +448,8 @@ public int hashCode() { additionalReconnectIntervalRange, raftDatabaseConfiguration, raftSnapshotsConfiguration, - raftClusterConfiguration + raftClusterConfiguration, + allAddresses ); } @@ -441,6 +468,7 @@ public String toString() { .add(DATABASE, raftDatabaseConfiguration) .add(SNAPSHOTS, raftSnapshotsConfiguration) .add(CLUSTER, raftClusterConfiguration) + .add(ALL_ADDRESSES, allAddresses) .toString(); } } diff --git a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCBase.java b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCBase.java index 8d967ef..c8c88fc 100644 --- a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCBase.java +++ b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCBase.java @@ -34,6 +34,7 @@ import javax.annotation.Nullable; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.DriverManager; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -123,8 +124,9 @@ public synchronized final void initialize() throws StorageException { checkNotNull(statement); try { try { - addDatabaseCreateStatementsToBatch(statement); + addDatabaseCreateStatementsToBatch(statement, connection.getMetaData()); statement.executeBatch(); + initializeDatabase(connection); } finally { closeSilently(statement); } @@ -159,7 +161,8 @@ protected final synchronized boolean isInitialized() { return initialized; } - protected abstract void addDatabaseCreateStatementsToBatch(Statement statement) throws Exception; + protected abstract void addDatabaseCreateStatementsToBatch(Statement batchStatement, DatabaseMetaData metadata) throws Exception; + protected abstract void initializeDatabase(Connection connection) throws Exception; private void setupConnection() throws SQLException { if (connection == null) { diff --git a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCLog.java b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCLog.java index 3b22cf4..44f7573 100644 --- a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCLog.java +++ b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCLog.java @@ -42,6 +42,7 @@ import java.io.IOException; import java.io.InputStream; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; @@ -54,7 +55,7 @@ /** * Implementation of {@code Log} that uses a JDBC backend. *
- * This implementation creates and uses a single table called {@code log_index} with the following structure: + * This implementation creates and uses a single table called {@code entries} with the following structure: *
* +-----------+-----------+----------+----------+
* | log_index | term | type | data |
@@ -128,11 +129,21 @@ public synchronized void setupCustomCommandSerializerAndDeserializer(CommandSeri
}
@Override
- protected void addDatabaseCreateStatementsToBatch(Statement statement) throws SQLException {
+ protected void addDatabaseCreateStatementsToBatch(Statement batchStatement, DatabaseMetaData metadata) throws SQLException {
LOGGER.info("setup raft log");
-
- statement.addBatch("CREATE TABLE IF NOT EXISTS entries(log_index BIGINT PRIMARY KEY, term BIGINT NOT NULL, type TINYINT NOT NULL, data BLOB DEFAULT NULL)");
- statement.addBatch("CREATE INDEX IF NOT EXISTS entries_index ON entries(log_index DESC)");
+
+ try (ResultSet rs = metadata.getTables(null, null, "ENTRIES", null)) {
+ if (! rs.next()) {
+ batchStatement.addBatch("CREATE TABLE entries(log_index BIGINT PRIMARY KEY, term BIGINT NOT NULL, "
+ + "type SMALLINT NOT NULL, data BLOB DEFAULT NULL)");
+ batchStatement.addBatch("CREATE INDEX entries_index ON entries(log_index DESC)");
+ }
+ }
+ }
+
+ @Override
+ protected void initializeDatabase(Connection connection) throws Exception {
+ // no special initialization required
}
@Override
diff --git a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCStore.java b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCStore.java
index 924c0c7..3ef127e 100644
--- a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCStore.java
+++ b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCStore.java
@@ -33,6 +33,7 @@
import javax.annotation.Nullable;
import java.sql.Connection;
+import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
@@ -95,12 +96,32 @@ public JDBCStore(String url, String username, @Nullable String password) {
}
@Override
- protected void addDatabaseCreateStatementsToBatch(Statement statement) throws SQLException {
+ protected void addDatabaseCreateStatementsToBatch(Statement batchStatement, DatabaseMetaData metadata) throws SQLException {
LOGGER.info("setup raft store");
+ try (ResultSet rs = metadata.getTables(null, null, "CURRENT_TERM", null)) {
+ if (! rs.next()) {
+ batchStatement.addBatch("CREATE TABLE current_term(term BIGINT NOT NULL)");
+ batchStatement.addBatch("INSERT INTO current_term (term) VALUES (-1)");
+ }
+ }
+ try (ResultSet rs = metadata.getTables(null, null, "COMMIT_INDEX", null)) {
+ if (! rs.next()) {
+ batchStatement.addBatch("CREATE TABLE commit_index(commit_index BIGINT NOT NULL)");
+ batchStatement.addBatch("INSERT INTO commit_index (commit_index) VALUES (-1)");
+ }
+ }
+ try (ResultSet rs = metadata.getTables(null, null, "VOTED_FOR", null)) {
+ if (! rs.next()) {
+ batchStatement.addBatch("CREATE TABLE voted_for(term BIGINT NOT NULL, server VARCHAR(128) DEFAULT NULL)");
+ batchStatement.addBatch("INSERT INTO voted_for(term) VALUES (-1)");
+ }
+ }
- statement.addBatch("CREATE TABLE IF NOT EXISTS current_term(term BIGINT NOT NULL)");
- statement.addBatch("CREATE TABLE IF NOT EXISTS commit_index(commit_index BIGINT NOT NULL)");
- statement.addBatch("CREATE TABLE IF NOT EXISTS voted_for(term BIGINT NOT NULL, server VARCHAR(128) DEFAULT NULL)");
+ }
+
+ @Override
+ protected void initializeDatabase(Connection connection) throws SQLException {
+ // do nothing
}
private Long queryAndCheckConsistency(PreparedStatement statement, final String tableName) throws Exception {
@@ -118,12 +139,16 @@ public Long use(ResultSet resultSet) throws Exception {
@Override
public synchronized long getCurrentTerm() throws StorageException {
try {
- return executeQuery("SELECT term FROM current_term", new StatementWithReturnBlock() {
+ long rtn = executeQuery("SELECT term FROM current_term", new StatementWithReturnBlock() {
@Override
public @Nullable Long use(PreparedStatement statement) throws Exception {
return queryAndCheckConsistency(statement, "current_term");
}
});
+ if (rtn < 0) {
+ throw new RuntimeException("Current term not set");
+ }
+ return rtn;
} catch (Exception e) {
throw new StorageException("fail get currentTerm", e);
}
@@ -135,23 +160,6 @@ public synchronized void setCurrentTerm(final long term) throws StorageException
execute(new ConnectionBlock() {
@Override
public void use(Connection connection) throws Exception {
- boolean doUpdate = withStatement(connection, "SELECT COUNT(*) FROM current_term", new StatementWithReturnBlock() {
- @Override
- public Boolean use(PreparedStatement statement) throws Exception {
- return withResultSet(statement, new ResultSetBlock() {
- @Override
- public Boolean use(ResultSet resultSet) throws Exception {
- resultSet.next(); // COUNT(*) should always return a value
-
- int count = resultSet.getInt(1);
- checkState(count == 0 || count == 1, "current_term: too many rows:%s", count);
-
- return count == 1;
- }
- });
- }
- });
- if (doUpdate) {
withStatement(connection, "UPDATE current_term SET term=?", new StatementBlock() {
@Override
public void use(PreparedStatement statement) throws Exception {
@@ -160,16 +168,6 @@ public void use(PreparedStatement statement) throws Exception {
checkState(rowsUpdated == 1, "commit_index: too many rows:%s)", rowsUpdated);
}
});
- } else {
- withStatement(connection, "INSERT INTO current_term VALUES(?)", new StatementBlock() {
- @Override
- public void use(PreparedStatement statement) throws Exception {
- statement.setLong(1, term);
- int rowsUpdated = statement.executeUpdate();
- checkState(rowsUpdated == 1, "commit_index: too many rows:%s)", rowsUpdated);
- }
- });
- }
}
});
} catch (Exception e) {
@@ -180,12 +178,16 @@ public void use(PreparedStatement statement) throws Exception {
@Override
public synchronized long getCommitIndex() throws StorageException {
try {
- return executeQuery("SELECT commit_index FROM commit_index", new StatementWithReturnBlock() {
+ long rtn = executeQuery("SELECT commit_index FROM commit_index", new StatementWithReturnBlock() {
@Override
public @Nullable Long use(PreparedStatement statement) throws Exception {
return queryAndCheckConsistency(statement, "commit_index");
}
});
+ if (rtn < 0) {
+ throw new RuntimeException("Commit index not set");
+ }
+ return rtn;
} catch (Exception e) {
throw new StorageException("fail get commitIndex", e);
}
@@ -197,23 +199,6 @@ public synchronized void setCommitIndex(final long logIndex) throws StorageExcep
execute(new ConnectionBlock() {
@Override
public void use(Connection connection) throws Exception {
- boolean doUpdate = withStatement(connection, "SELECT COUNT(*) FROM commit_index", new StatementWithReturnBlock() {
- @Override
- public Boolean use(PreparedStatement statement) throws Exception {
- return withResultSet(statement, new ResultSetBlock() {
- @Override
- public Boolean use(ResultSet resultSet) throws Exception {
- resultSet.next(); // COUNT(*) should always return a value
-
- int count = resultSet.getInt(1);
- checkState(count == 0 || count == 1, "commit_index: too many rows:%s", count);
-
- return count == 1;
- }
- });
- }
- });
- if (doUpdate) {
withStatement(connection, "UPDATE commit_index SET commit_index=?", new StatementBlock() {
@Override
public void use(PreparedStatement statement) throws Exception {
@@ -222,16 +207,6 @@ public void use(PreparedStatement statement) throws Exception {
checkState(rowsUpdated == 1, "commit_index: too many rows:%s)", rowsUpdated);
}
});
- } else {
- withStatement(connection, "INSERT INTO commit_index VALUES(?)", new StatementBlock() {
- @Override
- public void use(PreparedStatement statement) throws Exception {
- statement.setLong(1, logIndex);
- int rowsUpdated = statement.executeUpdate();
- checkState(rowsUpdated == 1, "commit_index: too many rows:%s)", rowsUpdated);
- }
- });
- }
}
});
} catch (Exception e) {
diff --git a/libraft-agent/src/main/java/io/libraft/agent/rpc/RaftNetworkClient.java b/libraft-agent/src/main/java/io/libraft/agent/rpc/RaftNetworkClient.java
index eb6d25f..ba8ae5f 100644
--- a/libraft-agent/src/main/java/io/libraft/agent/rpc/RaftNetworkClient.java
+++ b/libraft-agent/src/main/java/io/libraft/agent/rpc/RaftNetworkClient.java
@@ -121,6 +121,7 @@ public final class RaftNetworkClient implements RPCSender {
private final int minReconnectInterval;
private final int additionalReconnectIntervalRange;
private final TimeUnit timeUnit;
+ private final boolean listenOnAllAddresses;
private volatile boolean running; // set during start/stop and accessed by netty-I/O and RaftNetworkClient caller threads
@@ -135,6 +136,7 @@ public final class RaftNetworkClient implements RPCSender {
* @param timer instance of {@code Timer} used to schedule network timeouts
* @param mapper instance of {@code ObjectMapper} used to generate JSON representations of {@link RaftRPC} messages
* @param self unique id of the local Raft server
+ * @param listenOnAllAddresses 'true' if the server socket ignores the address and binds to all
* @param cluster set of unique ids - one for each Raft server in the cluster
* @param connectTimeout maximum time {@code RaftNetworkClient} waits to establish a connection to another Raft server
* @param minReconnectInterval minimum amount of time to wait before reconnecting to a Raft server
@@ -147,6 +149,7 @@ public RaftNetworkClient(
Timer timer,
ObjectMapper mapper,
RaftMember self,
+ boolean listenOnAllAddresses,
Set cluster,
int connectTimeout,
int minReconnectInterval,
@@ -161,6 +164,7 @@ public RaftNetworkClient(
this.timer = timer;
this.mapper = mapper;
this.self = self;
+ this.listenOnAllAddresses = listenOnAllAddresses;
this.connectTimeout = connectTimeout;
this.minReconnectInterval = minReconnectInterval;
this.additionalReconnectIntervalRange = additionalReconnectIntervalRange;
@@ -274,7 +278,9 @@ private SocketAddress getResolvedBindAddress() { // FIXME (AG): find a way to re
if (bindAddress instanceof InetSocketAddress) {
InetSocketAddress inetBindAddress = (InetSocketAddress) bindAddress;
- if (inetBindAddress.isUnresolved()) {
+ if (listenOnAllAddresses) {
+ bindAddress = new InetSocketAddress(inetBindAddress.getPort());
+ } else if (inetBindAddress.isUnresolved()) {
bindAddress = new InetSocketAddress(inetBindAddress.getHostName(), inetBindAddress.getPort());
}
}
diff --git a/libraft-agent/src/main/java/io/libraft/agent/snapshots/SnapshotsDAO.java b/libraft-agent/src/main/java/io/libraft/agent/snapshots/SnapshotsDAO.java
index 010811d..749a26e 100644
--- a/libraft-agent/src/main/java/io/libraft/agent/snapshots/SnapshotsDAO.java
+++ b/libraft-agent/src/main/java/io/libraft/agent/snapshots/SnapshotsDAO.java
@@ -30,6 +30,7 @@
import org.skife.jdbi.v2.ResultIterator;
import org.skife.jdbi.v2.StatementContext;
+import org.skife.jdbi.v2.exceptions.UnableToExecuteStatementException;
import org.skife.jdbi.v2.sqlobject.Bind;
import org.skife.jdbi.v2.sqlobject.BindBean;
import org.skife.jdbi.v2.sqlobject.SqlQuery;
@@ -68,8 +69,12 @@ abstract class SnapshotsDAO {
*/
@Transaction
void createSnapshotsTableWithIndex() {
- createSnapshotsTable();
- createTimestampIndexForSnapshotsTable();
+ try {
+ createSnapshotsTable();
+ createTimestampIndexForSnapshotsTable();
+ } catch(UnableToExecuteStatementException e) {
+ // nada
+ }
}
/**
@@ -77,13 +82,13 @@ void createSnapshotsTableWithIndex() {
*/
// FIXME (AG): I'm essentially using the timestamp as a primary key
// I did this is because auto-increment indices have different syntaxes in different dbs - might be best to create an explicit index
- @SqlUpdate("create table if not exists snapshots(filename varchar(255) not null, ts bigint unique, last_term bigint not null, last_index bigint not null)")
+ @SqlUpdate("create table snapshots(filename varchar(255) not null, ts bigint unique, last_term bigint not null, last_index bigint not null)")
abstract void createSnapshotsTable();
/**
* Create the index for the {@code snapshots} table.
*/
- @SqlUpdate("create index if not exists ts_index on snapshots(ts)")
+ @SqlUpdate("create index ts_index on snapshots(ts)")
abstract void createTimestampIndexForSnapshotsTable();
/**