diff --git a/libraft-agent/src/main/java/io/libraft/agent/RaftAgent.java b/libraft-agent/src/main/java/io/libraft/agent/RaftAgent.java index 9c815a1..f1e083d 100644 --- a/libraft-agent/src/main/java/io/libraft/agent/RaftAgent.java +++ b/libraft-agent/src/main/java/io/libraft/agent/RaftAgent.java @@ -234,6 +234,7 @@ private RaftAgent(RaftConfiguration configuration, RaftListener raftListener) { timer, mapper, getSelfAsMember(raftClusterConfiguration.getSelf(), cluster), + configuration.getAllAddresses(), cluster, configuration.getConnectTimeout(), configuration.getMinReconnectInterval(), diff --git a/libraft-agent/src/main/java/io/libraft/agent/RaftAgentConstants.java b/libraft-agent/src/main/java/io/libraft/agent/RaftAgentConstants.java index 83fb2f4..7d695fd 100644 --- a/libraft-agent/src/main/java/io/libraft/agent/RaftAgentConstants.java +++ b/libraft-agent/src/main/java/io/libraft/agent/RaftAgentConstants.java @@ -63,4 +63,10 @@ private RaftAgentConstants() { } // to protect instantiation * a successful connection to a Raft server. */ public static final int CONNECT_TIMEOUT = 5000; + + /** + * Indicates whether the local server's listener should listen + * on all addresses or only on the specified address + */ + public static final boolean ALL_ADDRESSES = false; } diff --git a/libraft-agent/src/main/java/io/libraft/agent/configuration/RaftConfiguration.java b/libraft-agent/src/main/java/io/libraft/agent/configuration/RaftConfiguration.java index 9ab2982..20e42fc 100644 --- a/libraft-agent/src/main/java/io/libraft/agent/configuration/RaftConfiguration.java +++ b/libraft-agent/src/main/java/io/libraft/agent/configuration/RaftConfiguration.java @@ -73,6 +73,7 @@ public final class RaftConfiguration { private static final String SNAPSHOTS = "snapshots"; private static final String DATABASE = "database"; private static final String CLUSTER = "cluster"; + private static final String ALL_ADDRESSES = "allAddresses"; @Min(1) @Max(RaftConfigurationConstants.SIXTY_SECONDS) @@ -115,6 +116,10 @@ public final class RaftConfiguration { @NotNull @JsonProperty(ADDITIONAL_RECONNECT_INTERVAL_RANGE) private int additionalReconnectIntervalRange = RaftAgentConstants.ADDITIONAL_RECONNECT_INTERVAL_RANGE; + + @NotNull + @JsonProperty(ALL_ADDRESSES) + private boolean allAddresses = RaftAgentConstants.ALL_ADDRESSES; @JsonIgnore private final TimeUnit timeUnit = RaftConfigurationConstants.DEFAULT_TIME_UNIT; @@ -330,6 +335,16 @@ public void setMinReconnectInterval(int minReconnectInterval) { public int getAdditionalReconnectIntervalRange() { return additionalReconnectIntervalRange; } + + /** + * Get a flag indicating whether the local server should listen on all addresses + * or on the specified one only + * + * @return the all-addresses flag + */ + public boolean getAllAddresses() { + return allAddresses; + } /** * Set the maximum additional amount of time added to @@ -343,6 +358,16 @@ public int getAdditionalReconnectIntervalRange() { public void setAdditionalReconnectIntervalRange(int additionalReconnectIntervalRange) { this.additionalReconnectIntervalRange = additionalReconnectIntervalRange; } + + /** + * Set the flag indicating whether the local server listens on all addresses + * or not + * + * @param aa the all-addresses value + */ + public void setAllAddresses(boolean aa) { + this.allAddresses = aa; + } /** * Get the Raft database configuration. @@ -406,7 +431,8 @@ public boolean equals(@Nullable Object o) { && additionalReconnectIntervalRange == other.additionalReconnectIntervalRange && raftDatabaseConfiguration.equals(other.raftDatabaseConfiguration) && raftSnapshotsConfiguration.equals(other.raftSnapshotsConfiguration) - && raftClusterConfiguration.equals(other.raftClusterConfiguration); + && raftClusterConfiguration.equals(other.raftClusterConfiguration) + && allAddresses == other.allAddresses; } @Override @@ -422,7 +448,8 @@ public int hashCode() { additionalReconnectIntervalRange, raftDatabaseConfiguration, raftSnapshotsConfiguration, - raftClusterConfiguration + raftClusterConfiguration, + allAddresses ); } @@ -441,6 +468,7 @@ public String toString() { .add(DATABASE, raftDatabaseConfiguration) .add(SNAPSHOTS, raftSnapshotsConfiguration) .add(CLUSTER, raftClusterConfiguration) + .add(ALL_ADDRESSES, allAddresses) .toString(); } } diff --git a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCBase.java b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCBase.java index 8d967ef..c8c88fc 100644 --- a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCBase.java +++ b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCBase.java @@ -34,6 +34,7 @@ import javax.annotation.Nullable; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.DriverManager; import java.sql.PreparedStatement; import java.sql.ResultSet; @@ -123,8 +124,9 @@ public synchronized final void initialize() throws StorageException { checkNotNull(statement); try { try { - addDatabaseCreateStatementsToBatch(statement); + addDatabaseCreateStatementsToBatch(statement, connection.getMetaData()); statement.executeBatch(); + initializeDatabase(connection); } finally { closeSilently(statement); } @@ -159,7 +161,8 @@ protected final synchronized boolean isInitialized() { return initialized; } - protected abstract void addDatabaseCreateStatementsToBatch(Statement statement) throws Exception; + protected abstract void addDatabaseCreateStatementsToBatch(Statement batchStatement, DatabaseMetaData metadata) throws Exception; + protected abstract void initializeDatabase(Connection connection) throws Exception; private void setupConnection() throws SQLException { if (connection == null) { diff --git a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCLog.java b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCLog.java index 3b22cf4..44f7573 100644 --- a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCLog.java +++ b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCLog.java @@ -42,6 +42,7 @@ import java.io.IOException; import java.io.InputStream; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; @@ -54,7 +55,7 @@ /** * Implementation of {@code Log} that uses a JDBC backend. *

- * This implementation creates and uses a single table called {@code log_index} with the following structure: + * This implementation creates and uses a single table called {@code entries} with the following structure: *

  * +-----------+-----------+----------+----------+
  * | log_index |   term    |   type   |   data   |
@@ -128,11 +129,21 @@ public synchronized void setupCustomCommandSerializerAndDeserializer(CommandSeri
     }
 
     @Override
-    protected void addDatabaseCreateStatementsToBatch(Statement statement) throws SQLException {
+    protected void addDatabaseCreateStatementsToBatch(Statement batchStatement, DatabaseMetaData metadata) throws SQLException {
         LOGGER.info("setup raft log");
-
-        statement.addBatch("CREATE TABLE IF NOT EXISTS entries(log_index BIGINT PRIMARY KEY, term BIGINT NOT NULL, type TINYINT NOT NULL, data BLOB DEFAULT NULL)");
-        statement.addBatch("CREATE INDEX IF NOT EXISTS entries_index ON entries(log_index DESC)");
+        
+        try (ResultSet rs = metadata.getTables(null, null, "ENTRIES", null)) {
+        	if (! rs.next()) {
+        		batchStatement.addBatch("CREATE TABLE entries(log_index BIGINT PRIMARY KEY, term BIGINT NOT NULL, "
+                		+ "type SMALLINT NOT NULL, data BLOB DEFAULT NULL)");
+        		batchStatement.addBatch("CREATE INDEX entries_index ON entries(log_index DESC)");
+        	}
+        }
+    }
+    
+    @Override
+    protected void initializeDatabase(Connection connection) throws Exception {
+    	// no special initialization required
     }
 
     @Override
diff --git a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCStore.java b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCStore.java
index 924c0c7..3ef127e 100644
--- a/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCStore.java
+++ b/libraft-agent/src/main/java/io/libraft/agent/persistence/JDBCStore.java
@@ -33,6 +33,7 @@
 
 import javax.annotation.Nullable;
 import java.sql.Connection;
+import java.sql.DatabaseMetaData;
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
@@ -95,12 +96,32 @@ public JDBCStore(String url, String username, @Nullable String password) {
     }
 
     @Override
-    protected void addDatabaseCreateStatementsToBatch(Statement statement) throws SQLException {
+    protected void addDatabaseCreateStatementsToBatch(Statement batchStatement, DatabaseMetaData metadata) throws SQLException {
         LOGGER.info("setup raft store");
+        try (ResultSet rs = metadata.getTables(null, null, "CURRENT_TERM", null)) {
+        	if (! rs.next()) {
+                batchStatement.addBatch("CREATE TABLE current_term(term BIGINT NOT NULL)");
+                batchStatement.addBatch("INSERT INTO current_term (term) VALUES (-1)");
+        	}
+        }
+        try (ResultSet rs = metadata.getTables(null, null, "COMMIT_INDEX", null)) {
+        	if (! rs.next()) {
+                batchStatement.addBatch("CREATE TABLE commit_index(commit_index BIGINT NOT NULL)");
+                batchStatement.addBatch("INSERT INTO commit_index (commit_index) VALUES (-1)");
+        	}
+        }
+        try (ResultSet rs = metadata.getTables(null, null, "VOTED_FOR", null)) {
+        	if (! rs.next()) {
+                batchStatement.addBatch("CREATE TABLE voted_for(term BIGINT NOT NULL, server VARCHAR(128) DEFAULT NULL)");
+                batchStatement.addBatch("INSERT INTO voted_for(term) VALUES (-1)");
+        	}
+        }
 
-        statement.addBatch("CREATE TABLE IF NOT EXISTS current_term(term BIGINT NOT NULL)");
-        statement.addBatch("CREATE TABLE IF NOT EXISTS commit_index(commit_index BIGINT NOT NULL)");
-        statement.addBatch("CREATE TABLE IF NOT EXISTS voted_for(term BIGINT NOT NULL, server VARCHAR(128) DEFAULT NULL)");
+    }
+    
+    @Override
+    protected void initializeDatabase(Connection connection) throws SQLException {
+    	// do nothing
     }
 
     private Long queryAndCheckConsistency(PreparedStatement statement, final String tableName) throws Exception {
@@ -118,12 +139,16 @@ public Long use(ResultSet resultSet) throws Exception {
     @Override
     public synchronized long getCurrentTerm() throws StorageException {
         try {
-            return executeQuery("SELECT term FROM current_term", new StatementWithReturnBlock() {
+            long rtn = executeQuery("SELECT term FROM current_term", new StatementWithReturnBlock() {
                 @Override
                 public @Nullable Long use(PreparedStatement statement) throws Exception {
                     return queryAndCheckConsistency(statement, "current_term");
                 }
             });
+            if (rtn < 0) {
+            	throw new RuntimeException("Current term not set");
+            }
+            return rtn;
         } catch (Exception e) {
             throw new StorageException("fail get currentTerm", e);
         }
@@ -135,23 +160,6 @@ public synchronized void setCurrentTerm(final long term) throws StorageException
             execute(new ConnectionBlock() {
                 @Override
                 public void use(Connection connection) throws Exception {
-                    boolean doUpdate = withStatement(connection, "SELECT COUNT(*) FROM current_term", new StatementWithReturnBlock() {
-                        @Override
-                        public Boolean use(PreparedStatement statement) throws Exception {
-                            return withResultSet(statement, new ResultSetBlock() {
-                                @Override
-                                public Boolean use(ResultSet resultSet) throws Exception {
-                                    resultSet.next(); // COUNT(*) should always return a value
-
-                                    int count = resultSet.getInt(1);
-                                    checkState(count == 0 || count == 1, "current_term: too many rows:%s", count);
-
-                                    return count == 1;
-                                }
-                            });
-                        }
-                    });
-                    if (doUpdate) {
                         withStatement(connection, "UPDATE current_term SET term=?", new StatementBlock() {
                             @Override
                             public void use(PreparedStatement statement) throws Exception {
@@ -160,16 +168,6 @@ public void use(PreparedStatement statement) throws Exception {
                                 checkState(rowsUpdated == 1, "commit_index: too many rows:%s)", rowsUpdated);
                             }
                         });
-                    } else {
-                        withStatement(connection, "INSERT INTO current_term VALUES(?)", new StatementBlock() {
-                            @Override
-                            public void use(PreparedStatement statement) throws Exception {
-                                statement.setLong(1, term);
-                                int rowsUpdated = statement.executeUpdate();
-                                checkState(rowsUpdated == 1, "commit_index: too many rows:%s)", rowsUpdated);
-                            }
-                        });
-                    }
                 }
             });
         } catch (Exception e) {
@@ -180,12 +178,16 @@ public void use(PreparedStatement statement) throws Exception {
     @Override
     public synchronized long getCommitIndex() throws StorageException {
         try {
-            return executeQuery("SELECT commit_index FROM commit_index", new StatementWithReturnBlock() {
+            long rtn = executeQuery("SELECT commit_index FROM commit_index", new StatementWithReturnBlock() {
                 @Override
                 public @Nullable Long use(PreparedStatement statement) throws Exception {
                     return queryAndCheckConsistency(statement, "commit_index");
                 }
             });
+            if (rtn < 0) {
+            	throw new RuntimeException("Commit index not set");
+            }
+            return rtn;
         } catch (Exception e) {
             throw new StorageException("fail get commitIndex", e);
         }
@@ -197,23 +199,6 @@ public synchronized void setCommitIndex(final long logIndex) throws StorageExcep
             execute(new ConnectionBlock() {
                 @Override
                 public void use(Connection connection) throws Exception {
-                    boolean doUpdate = withStatement(connection, "SELECT COUNT(*) FROM commit_index", new StatementWithReturnBlock() {
-                        @Override
-                        public Boolean use(PreparedStatement statement) throws Exception {
-                            return withResultSet(statement, new ResultSetBlock() {
-                                @Override
-                                public Boolean use(ResultSet resultSet) throws Exception {
-                                    resultSet.next(); // COUNT(*) should always return a value
-
-                                    int count = resultSet.getInt(1);
-                                    checkState(count == 0 || count == 1, "commit_index: too many rows:%s", count);
-
-                                    return count == 1;
-                                }
-                            });
-                        }
-                    });
-                    if (doUpdate) {
                         withStatement(connection, "UPDATE commit_index SET commit_index=?", new StatementBlock() {
                             @Override
                             public void use(PreparedStatement statement) throws Exception {
@@ -222,16 +207,6 @@ public void use(PreparedStatement statement) throws Exception {
                                 checkState(rowsUpdated == 1, "commit_index: too many rows:%s)", rowsUpdated);
                             }
                         });
-                    } else {
-                        withStatement(connection, "INSERT INTO commit_index VALUES(?)", new StatementBlock() {
-                            @Override
-                            public void use(PreparedStatement statement) throws Exception {
-                                statement.setLong(1, logIndex);
-                                int rowsUpdated = statement.executeUpdate();
-                                checkState(rowsUpdated == 1, "commit_index: too many rows:%s)", rowsUpdated);
-                            }
-                        });
-                    }
                 }
             });
         } catch (Exception e) {
diff --git a/libraft-agent/src/main/java/io/libraft/agent/rpc/RaftNetworkClient.java b/libraft-agent/src/main/java/io/libraft/agent/rpc/RaftNetworkClient.java
index eb6d25f..ba8ae5f 100644
--- a/libraft-agent/src/main/java/io/libraft/agent/rpc/RaftNetworkClient.java
+++ b/libraft-agent/src/main/java/io/libraft/agent/rpc/RaftNetworkClient.java
@@ -121,6 +121,7 @@ public final class RaftNetworkClient implements RPCSender {
     private final int minReconnectInterval;
     private final int additionalReconnectIntervalRange;
     private final TimeUnit timeUnit;
+    private final boolean listenOnAllAddresses;
 
     private volatile boolean running; // set during start/stop and accessed by netty-I/O and RaftNetworkClient caller threads
 
@@ -135,6 +136,7 @@ public final class RaftNetworkClient implements RPCSender {
      * @param timer instance of {@code Timer} used to schedule network timeouts
      * @param mapper instance of {@code ObjectMapper} used to generate JSON representations of {@link RaftRPC} messages
      * @param self unique id of the local Raft server
+     * @param listenOnAllAddresses 'true' if the server socket ignores the address and binds to all
      * @param cluster set of unique ids - one for each Raft server in the cluster
      * @param connectTimeout maximum time {@code RaftNetworkClient} waits to establish a connection to another Raft server
      * @param minReconnectInterval minimum amount of time to wait before reconnecting to a Raft server
@@ -147,6 +149,7 @@ public RaftNetworkClient(
             Timer timer,
             ObjectMapper mapper,
             RaftMember self,
+            boolean listenOnAllAddresses,
             Set cluster,
             int connectTimeout,
             int minReconnectInterval,
@@ -161,6 +164,7 @@ public RaftNetworkClient(
         this.timer = timer;
         this.mapper = mapper;
         this.self = self;
+        this.listenOnAllAddresses = listenOnAllAddresses;
         this.connectTimeout = connectTimeout;
         this.minReconnectInterval = minReconnectInterval;
         this.additionalReconnectIntervalRange = additionalReconnectIntervalRange;
@@ -274,7 +278,9 @@ private SocketAddress getResolvedBindAddress() { // FIXME (AG): find a way to re
 
         if (bindAddress instanceof InetSocketAddress) {
             InetSocketAddress inetBindAddress = (InetSocketAddress) bindAddress;
-            if (inetBindAddress.isUnresolved()) {
+            if (listenOnAllAddresses) {
+            	bindAddress = new InetSocketAddress(inetBindAddress.getPort());
+            } else if (inetBindAddress.isUnresolved()) {
                 bindAddress = new InetSocketAddress(inetBindAddress.getHostName(), inetBindAddress.getPort());
             }
         }
diff --git a/libraft-agent/src/main/java/io/libraft/agent/snapshots/SnapshotsDAO.java b/libraft-agent/src/main/java/io/libraft/agent/snapshots/SnapshotsDAO.java
index 010811d..749a26e 100644
--- a/libraft-agent/src/main/java/io/libraft/agent/snapshots/SnapshotsDAO.java
+++ b/libraft-agent/src/main/java/io/libraft/agent/snapshots/SnapshotsDAO.java
@@ -30,6 +30,7 @@
 
 import org.skife.jdbi.v2.ResultIterator;
 import org.skife.jdbi.v2.StatementContext;
+import org.skife.jdbi.v2.exceptions.UnableToExecuteStatementException;
 import org.skife.jdbi.v2.sqlobject.Bind;
 import org.skife.jdbi.v2.sqlobject.BindBean;
 import org.skife.jdbi.v2.sqlobject.SqlQuery;
@@ -68,8 +69,12 @@ abstract class SnapshotsDAO {
      */
     @Transaction
     void createSnapshotsTableWithIndex() {
-        createSnapshotsTable();
-        createTimestampIndexForSnapshotsTable();
+    	try {
+          createSnapshotsTable();
+          createTimestampIndexForSnapshotsTable();
+    	} catch(UnableToExecuteStatementException e) {
+    		// nada
+    	}
     }
 
     /**
@@ -77,13 +82,13 @@ void createSnapshotsTableWithIndex() {
      */
     // FIXME (AG): I'm essentially using the timestamp as a primary key
     // I did this is because auto-increment indices have different syntaxes in different dbs - might be best to create an explicit index
-    @SqlUpdate("create table if not exists snapshots(filename varchar(255) not null, ts bigint unique, last_term bigint not null, last_index bigint not null)")
+    @SqlUpdate("create table snapshots(filename varchar(255) not null, ts bigint unique, last_term bigint not null, last_index bigint not null)")
     abstract void createSnapshotsTable();
 
     /**
      * Create the index for the {@code snapshots} table.
      */
-    @SqlUpdate("create index if not exists ts_index on snapshots(ts)")
+    @SqlUpdate("create index ts_index on snapshots(ts)")
     abstract void createTimestampIndexForSnapshotsTable();
 
     /**